diff --git a/xla/service/all_reduce_reassociate.cc b/xla/service/all_reduce_reassociate.cc index 6063eef7b6e6b..56b4fe257d4ea 100644 --- a/xla/service/all_reduce_reassociate.cc +++ b/xla/service/all_reduce_reassociate.cc @@ -86,7 +86,7 @@ HloInstruction* LookThroughForAllReduce(HloInstruction* instr, const Literal& reduction_identity) { // Match reduce-scatter pattern. Support only the non-formatted case at the // moment. - if (instr->opcode() == HloOpcode::kDynamicSlice) { + if (HloPredicateIsOp(instr)) { // Dynamic-slice to be matched needs to be immediately using an AllReduce. if (instr->operand(0)->opcode() != HloOpcode::kAllReduce || instr->operand(0)->user_count() != 1 || instr->user_count() != 1) { @@ -94,17 +94,15 @@ HloInstruction* LookThroughForAllReduce(HloInstruction* instr, } return instr; } - while (instr->opcode() != HloOpcode::kAllReduce) { + while (HloPredicateIsNotOp(instr)) { if (instr->user_count() != 1) { return nullptr; } - if (instr->opcode() != HloOpcode::kReshape && - instr->opcode() != HloOpcode::kPad && - instr->opcode() != HloOpcode::kSlice && - instr->opcode() != HloOpcode::kConvert) { + if (HloPredicateIsNotOp(instr)) { return nullptr; } - if (instr->opcode() == HloOpcode::kPad) { + if (HloPredicateIsOp(instr)) { if (!instr->operand(1)->IsConstant()) { return nullptr; } @@ -223,7 +221,7 @@ absl::StatusOr AllReduceReassociate::Run( continue; } if (lhs->opcode() != rhs->opcode() || - (lhs->opcode() == HloOpcode::kDynamicSlice && + (HloPredicateIsOp(lhs) && !ShapeUtil::Compatible(lhs->operand(0)->shape(), rhs->operand(0)->shape()))) { continue; @@ -232,7 +230,7 @@ absl::StatusOr AllReduceReassociate::Run( HloAllReduceInstruction* ar1 = nullptr; bool reduce_scatter_pattern_match = false; // Check Dynamic-slice pattern is identical - if (lhs->opcode() == HloOpcode::kDynamicSlice) { + if (HloPredicateIsOp(lhs)) { HloInstruction* original_rhs_operand = rhs->mutable_operand(0); TF_RETURN_IF_ERROR(rhs->ReplaceOperandWith(0, lhs->mutable_operand(0))); if (!lhs->Identical(*rhs)) { diff --git a/xla/service/all_reduce_simplifier.cc b/xla/service/all_reduce_simplifier.cc index c51492f0550cc..1a125d934bd83 100644 --- a/xla/service/all_reduce_simplifier.cc +++ b/xla/service/all_reduce_simplifier.cc @@ -81,8 +81,8 @@ absl::StatusOr AllReduceSimplifier::Run( for (auto computation : module->computations(execution_threads)) { for (HloInstruction* inst : computation->MakeInstructionPostOrder()) { // AllGather and ReduceScatter with the same input and output shape - if ((inst->opcode() == HloOpcode::kAllGather || - inst->opcode() == HloOpcode::kReduceScatter) && + if ((HloPredicateIsOp( + inst)) && ShapeUtil::Compatible(inst->shape(), inst->operand(0)->shape())) { changed = true; TF_RETURN_IF_ERROR( diff --git a/xla/service/buffer_assignment.cc b/xla/service/buffer_assignment.cc index 12540a782a8db..1f73f5622f0b7 100644 --- a/xla/service/buffer_assignment.cc +++ b/xla/service/buffer_assignment.cc @@ -316,7 +316,7 @@ static const HloInstruction* GetEntryParameterInstruction( for (const auto& p : alloc.assigned_buffers()) { const HloValue* value = p.first; const HloInstruction* instr = value->instruction(); - if (instr->opcode() == HloOpcode::kParameter && + if (HloPredicateIsOp(instr) && instr->parent() == instr->GetModule()->entry_computation()) { return instr; } @@ -1047,7 +1047,7 @@ std::string BufferAssignment::ToVerboseString( buf_strs.push_back(absl::StrCat( "\n\t\tOperator: ", xla::OpMetadataToString(instr->metadata()))); } - if (instr->opcode() == HloOpcode::kParameter && + if (HloPredicateIsOp(instr) && (instr->parent() == instr->GetModule()->entry_computation())) { // Special case on entry parameters as they sometimes have hundreds of // indices in their shapes, and overwhelm the output. @@ -1483,7 +1483,7 @@ absl::Status BufferAssigner::AssignSingleHloBuffer( const HloInstruction* instruction = value->instruction(); const bool is_entry_parameter = - instruction->opcode() == HloOpcode::kParameter && + HloPredicateIsOp(instruction) && instruction->parent() == instruction->GetModule()->entry_computation(); if (is_entry_parameter) { diff --git a/xla/service/collective_ops_utils.cc b/xla/service/collective_ops_utils.cc index 8c0e1ee86c435..503c39405ffbf 100644 --- a/xla/service/collective_ops_utils.cc +++ b/xla/service/collective_ops_utils.cc @@ -167,11 +167,12 @@ absl::StatusOr GetCollectiveOpGroupMode( absl::StatusOr GetCollectiveUseGlobalDeviceIds( const HloInstruction* hlo) { - const bool is_all_reduce = (hlo->opcode() == HloOpcode::kAllReduce || - hlo->opcode() == HloOpcode::kAllReduceStart || - hlo->opcode() == HloOpcode::kReduceScatter); - const bool is_all_gather = (hlo->opcode() == HloOpcode::kAllGather || - hlo->opcode() == HloOpcode::kAllGatherStart); + const bool is_all_reduce = + (HloPredicateIsOp(hlo)); + const bool is_all_gather = + (HloPredicateIsOp( + hlo)); if (!is_all_reduce && !is_all_gather) { return absl::InvalidArgumentError( "GetReplicaGroupCountAndSize only supports AllReduce and AllGather."); @@ -745,7 +746,7 @@ bool IsCollective(const HloInstruction* instruction) { if (IsNonFusionCollective(instruction)) { return true; } - if (instruction->opcode() == HloOpcode::kFusion && + if (HloPredicateIsOp(instruction) && instruction->IsCustomFusion()) { for (const auto* inner_inst : instruction->fused_instructions()) { if (IsCollective(inner_inst)) { @@ -757,7 +758,7 @@ bool IsCollective(const HloInstruction* instruction) { } HloInstruction* IsOrHasCollectiveWithChannelId(HloInstruction* instruction) { - if (instruction->opcode() == HloOpcode::kFusion) { + if (HloPredicateIsOp(instruction)) { for (auto* inner_inst : instruction->fused_instructions()) { if (IsOrHasCollectiveWithChannelId(inner_inst) != nullptr) { return inner_inst; diff --git a/xla/service/conditional_code_motion_test.cc b/xla/service/conditional_code_motion_test.cc index 1398a9b1fdc8d..69820d5fb35d8 100644 --- a/xla/service/conditional_code_motion_test.cc +++ b/xla/service/conditional_code_motion_test.cc @@ -359,9 +359,9 @@ ENTRY main { for (int i = 0; i < on_false->root_instruction()->operand_count(); ++i) { const HloInstruction* root_operand = on_false->root_instruction()->operand(i); - if (root_operand->opcode() == HloOpcode::kAdd) { + if (HloPredicateIsOp(root_operand)) { on_false_add_idx = i; - } else if (root_operand->opcode() == HloOpcode::kSubtract) { + } else if (HloPredicateIsOp(root_operand)) { on_false_sub_idx = i; } } @@ -425,9 +425,9 @@ ENTRY main { for (int i = 0; i < on_false->root_instruction()->operand_count(); ++i) { const HloInstruction* root_operand = on_false->root_instruction()->operand(i); - if (root_operand->opcode() == HloOpcode::kGetTupleElement) { + if (HloPredicateIsOp(root_operand)) { on_false_gte_idx = i; - } else if (root_operand->opcode() == HloOpcode::kConstant) { + } else if (HloPredicateIsOp(root_operand)) { on_false_const_idx = i; } } @@ -501,9 +501,9 @@ ENTRY main { for (int i = 0; i < on_false->root_instruction()->operand_count(); ++i) { const HloInstruction* root_operand = on_false->root_instruction()->operand(i); - if (root_operand->opcode() == HloOpcode::kConstant) { + if (HloPredicateIsOp(root_operand)) { on_false_const_idx = i; - } else if (root_operand->opcode() == HloOpcode::kGetTupleElement) { + } else if (HloPredicateIsOp(root_operand)) { on_false_gte_idx = i; } } diff --git a/xla/service/conditional_simplifier.cc b/xla/service/conditional_simplifier.cc index edd01e32da60c..75bc69c2e4bea 100644 --- a/xla/service/conditional_simplifier.cc +++ b/xla/service/conditional_simplifier.cc @@ -77,7 +77,7 @@ absl::StatusOr TryRemoveUnusedConditionalOperands( for (HloInstruction* user : param->users()) { // If the user is not a get tuple element, assume it is unsafe to remove // elements from the tuple. - if (user->opcode() != HloOpcode::kGetTupleElement) { + if (HloPredicateIsNotOp(user)) { return false; } tuple_indices_to_keep.insert(user->tuple_index()); @@ -215,7 +215,7 @@ bool RemoveUnusedTupleElements(HloInstruction* conditional_op) { std::vector used_indices(old_tuple_shapes_size, false); for (const HloInstruction* user : conditional_op->users()) { // We only deal with the case where all users are GTE instructions. - if (user->opcode() != HloOpcode::kGetTupleElement) { + if (HloPredicateIsNotOp(user)) { VLOG(3) << "Skip RemoveUnusedTupleElements due to non-GTE user:\n" << user->ToShortString(); return false; @@ -363,7 +363,7 @@ bool MergeDuplicateTupleElements(HloInstruction* conditional) { } for (const HloInstruction* user : conditional->users()) { - if (user->opcode() != HloOpcode::kGetTupleElement) { + if (HloPredicateIsNotOp(user)) { VLOG(3) << "Skip MergeDuplicateTupleElements due not all users are " "kGetTupleElement:\n" << conditional->ToShortString(); @@ -614,7 +614,7 @@ absl::StatusOr ConditionalSimplifier::Run( std::vector conditional_ops; for (auto* comp : module->computations(execution_threads)) { for (auto* instr : comp->MakeInstructionPostOrder()) { - if (instr->opcode() == HloOpcode::kConditional) { + if (HloPredicateIsOp(instr)) { // Verifier wants a single send/recv with a given channel. This pass // clones computations which can result in that getting violated. if (InstructionCallsChannelInstructions(*instr)) { diff --git a/xla/service/dynamic_dimension_inference.cc b/xla/service/dynamic_dimension_inference.cc index 97436ca78d229..c7f77cb63a073 100644 --- a/xla/service/dynamic_dimension_inference.cc +++ b/xla/service/dynamic_dimension_inference.cc @@ -946,7 +946,7 @@ absl::Status DynamicDimensionInferenceVisitor::HandleSetDimensionSize( HloInstruction* hlo) { bool dimension_is_static = false; const HloInstruction* size = hlo->operand(1); - if (size->opcode() == HloOpcode::kConstant) { + if (HloPredicateIsOp(size)) { // Check if we are setting a dimension size to its static size. If so, // removes the dynamic dimension. // @@ -1332,7 +1332,7 @@ absl::Status DynamicDimensionInferenceVisitor::HandleReshape( return false; } VLOG(3) << "Found " << found_dims.size() << "\n"; - if (op->opcode() == HloOpcode::kReshape) { + if (HloPredicateIsOp(op)) { for (auto op_dim_index : found_dims) { auto orig_reshape_pair = find_reshape_group_pair(op, op_dim_index); if (is_reverse_reshape_group_pair(op, orig_reshape_pair, hlo, @@ -2259,7 +2259,7 @@ absl::Status DynamicDimensionInferenceVisitor::HandleWhile( // hlo->while_condition()->parameter_instruction(0); TF_ASSIGN_OR_RETURN(WhileUtil::MakeInstructionsLiveInResult result, WhileUtil::MakeInstructionsLiveIn(hlo, operands_to_add)); - TF_RET_CHECK(result.replacement_instr->opcode() == HloOpcode::kTuple); + TF_RET_CHECK(HloPredicateIsOp(result.replacement_instr)); // WhileUtil creates a new while hlo and tuple. Update the dynamic size // mapping for the newly created tuple. HloInstruction* new_tuple_operand = @@ -2367,7 +2367,7 @@ absl::Status DynamicDimensionInferenceVisitor::HandleWhile( TF_RET_CHECK(!index.empty()); HloInstruction* gte = result.replacement_instr->mutable_operand(index.front()); - TF_RET_CHECK(gte->opcode() == HloOpcode::kGetTupleElement); + TF_RET_CHECK(HloPredicateIsOp(gte)); TF_RET_CHECK(gte->operand(0) == hlo); ShapeUtil::GetMutableSubshape(gte->mutable_shape(), ShapeIndexView(index).subspan(1)) @@ -2432,17 +2432,14 @@ absl::StatusOr DynamicDimensionInferenceVisitor::RequiresPadToStatic( auto uses = dataflow_analysis_.GetValueDefinedAt(instr, shape_index).GetUses(); for (const auto& use : uses) { - if (use.instruction->opcode() == HloOpcode::kAsyncStart || - use.instruction->opcode() == HloOpcode::kAsyncUpdate || - use.instruction->opcode() == HloOpcode::kAsyncDone || - use.instruction->opcode() == HloOpcode::kCall || - use.instruction->opcode() == HloOpcode::kTuple || - use.instruction->opcode() == HloOpcode::kGetTupleElement || - use.instruction->opcode() == HloOpcode::kConditional) { + if (HloPredicateIsOp(use.instruction)) { // These uses do not require padding as they do not operate the data. continue; } - if (use.instruction->opcode() == HloOpcode::kWhile) { + if (HloPredicateIsOp(use.instruction)) { TF_RET_CHECK(use.operand_number == 0); HloInstruction* root = use.instruction->while_body()->root_instruction(); if (parent_->HasDynamicDimension(root, use.operand_index)) { @@ -2450,17 +2447,17 @@ absl::StatusOr DynamicDimensionInferenceVisitor::RequiresPadToStatic( } continue; } - if (use.instruction->opcode() == HloOpcode::kSetDimensionSize) { + if (HloPredicateIsOp(use.instruction)) { // The dynamic size cannot itself be dynamic. TF_RET_CHECK(use.operand_number == 0); // SetDimensionSize will be removed, so the array must be padded if it // is a user of the array. return true; } - if (use.instruction->opcode() == HloOpcode::kGetDimensionSize) { + if (HloPredicateIsOp(use.instruction)) { return true; } - if (use.instruction->opcode() != HloOpcode::kCustomCall || + if (HloPredicateIsNotOp(use.instruction) || !use.instruction->IsCustomCall({"PadToStatic", "Sharding", "SPMDShardToFullShape", "SPMDFullToShardShape"})) { @@ -2863,7 +2860,7 @@ bool DynamicDimensionInference::CanInfer(HloInstruction* hlo) { // However, if there are called computations, we may need to run inference on // them. Similarly, custom calls can do anything based on the user callbacks. if (hlo->shape().is_static() && hlo->called_computations().empty() && - hlo->opcode() != HloOpcode::kCustomCall) { + HloPredicateIsNotOp(hlo)) { return false; } // The dimensions of all operands must either be 1) not dynamic, or 2) have a diff --git a/xla/service/dynamic_padder_test.cc b/xla/service/dynamic_padder_test.cc index 29e3724bcd79c..3fc49508c053b 100644 --- a/xla/service/dynamic_padder_test.cc +++ b/xla/service/dynamic_padder_test.cc @@ -65,7 +65,7 @@ namespace m = ::xla::match; namespace op = xla::testing::opcode_matchers; OpDynamismSupport OpHasDynamismSupport(HloInstruction* hlo) { - if (hlo->opcode() != HloOpcode::kCustomCall) { + if (HloPredicateIsNotOp(hlo)) { return OpDynamismSupport::kNoSupport; } if (hlo->custom_call_target() == "OpWithDynamicLowering") { @@ -591,7 +591,7 @@ ENTRY main { HloInstruction* while_inst = nullptr; for (HloInstruction* inst : module_->entry_computation()->MakeInstructionPostOrder()) { - if (inst->opcode() == HloOpcode::kWhile) { + if (HloPredicateIsOp(inst)) { ASSERT_EQ(while_inst, nullptr) << "while_inst: " << while_inst->name() << ", inst: " << inst->name(); while_inst = inst; @@ -674,7 +674,7 @@ ENTRY main { module_ = GetHloModule(hlo_text); auto op_supports_dynamism = [](HloInstruction* hlo) { - if (hlo->opcode() != HloOpcode::kCustomCall) { + if (HloPredicateIsNotOp(hlo)) { return OpDynamismSupport::kNoSupport; } if (hlo->custom_call_target() == "ComputeActivations" || @@ -697,7 +697,7 @@ ENTRY main { for (HloComputation* computation : module_->computations()) { for (HloInstruction* instruction : computation->instructions()) { - if (instruction->opcode() == HloOpcode::kCustomCall) { + if (HloPredicateIsOp(instruction)) { EXPECT_NE(instruction->custom_call_target(), "PadToStatic"); EXPECT_NE(instruction->custom_call_target(), "SliceToDynamic"); if (instruction->custom_call_target() == "ComputeActivations") { @@ -705,7 +705,7 @@ ENTRY main { } else if (instruction->custom_call_target() == "ApplyGradients") { EXPECT_TRUE(instruction->operand(1)->shape().is_dynamic()); } - } else if (instruction->opcode() == HloOpcode::kWhile) { + } else if (HloPredicateIsOp(instruction)) { const Shape& shape = instruction->shape(); EXPECT_TRUE(shape.tuple_shapes(1).is_dynamic()); EXPECT_TRUE(shape.tuple_shapes(3).is_dynamic()); diff --git a/xla/service/gather_expander_test.cc b/xla/service/gather_expander_test.cc index a7f39c326336c..67bdf7b29d689 100644 --- a/xla/service/gather_expander_test.cc +++ b/xla/service/gather_expander_test.cc @@ -96,7 +96,7 @@ ENTRY main { HloInstruction* while_instr = nullptr; for (auto* instr : module->entry_computation()->instructions()) { - if (instr->opcode() == HloOpcode::kWhile) { + if (HloPredicateIsOp(instr)) { ASSERT_EQ(while_instr, nullptr) << "Expected exactly one while instruction in the entry computation " "after gather expansion"; @@ -159,7 +159,7 @@ ENTRY main { HloInstruction* while_instr = nullptr; for (auto* instr : module->entry_computation()->instructions()) { - if (instr->opcode() == HloOpcode::kWhile) { + if (HloPredicateIsOp(instr)) { ASSERT_EQ(while_instr, nullptr) << "Expected exactly one while instruction in the entry computation " "after gather expansion"; diff --git a/xla/service/hlo_domain_remover.cc b/xla/service/hlo_domain_remover.cc index 4f2d2efc6afc3..9f9a4bb3763eb 100644 --- a/xla/service/hlo_domain_remover.cc +++ b/xla/service/hlo_domain_remover.cc @@ -108,7 +108,7 @@ absl::StatusOr HloDomainRemover::RemoveExitDomains( // users vector could be changed during the loop(e.g. ReplaceAllUsesWith). const std::vector users(instruction->users()); for (HloInstruction* user : users) { - if (user->opcode() == HloOpcode::kDomain && + if (HloPredicateIsOp(user) && user->user_side_metadata().Kind() == domain_kind && user->operand_side_metadata().Kind() == domain_kind) { VLOG(5) << "Removing exit domain " << user->name(); diff --git a/xla/service/hlo_domain_verifier.cc b/xla/service/hlo_domain_verifier.cc index 519f572c3dfbb..776f6326c3f66 100644 --- a/xla/service/hlo_domain_verifier.cc +++ b/xla/service/hlo_domain_verifier.cc @@ -52,7 +52,7 @@ absl::Status HloDomainVerifier::RunContext::PopulateDomainKinds( for (HloComputation* computation : module_->computations(execution_threads)) { for (HloInstruction* instruction : computation->instructions()) { - if (instruction->opcode() == HloOpcode::kDomain) { + if (HloPredicateIsOp(instruction)) { TF_RET_CHECK(instruction->user_side_metadata().Kind() == instruction->operand_side_metadata().Kind()) << instruction->ToString(); diff --git a/xla/service/latency_hiding_scheduler_test.cc b/xla/service/latency_hiding_scheduler_test.cc index 76e4fce0a9597..c39aa411a64b9 100644 --- a/xla/service/latency_hiding_scheduler_test.cc +++ b/xla/service/latency_hiding_scheduler_test.cc @@ -62,13 +62,13 @@ bool MaxConcurrentCollectivePermutesBelowThreshold( int max_concurrent_collective_permutes = 0; int num_concurrent_collective_permutes = 0; for (HloInstruction* instruction : instruction_sequence) { - if (instruction->opcode() == HloOpcode::kCollectivePermuteStart) { + if (HloPredicateIsOp(instruction)) { num_concurrent_collective_permutes += 1; max_concurrent_collective_permutes = std::max(max_concurrent_collective_permutes, num_concurrent_collective_permutes); } - if (instruction->opcode() == HloOpcode::kCollectivePermuteDone) { + if (HloPredicateIsOp(instruction)) { num_concurrent_collective_permutes -= 1; } } @@ -125,7 +125,8 @@ class TestLatencyEstimator : public LatencyEstimator { ? kMediumCost : kLowCost * ShapeUtil::ElementsIn(instr->shape()); } - if (instr->IsOutputFusion() || instr->opcode() == HloOpcode::kConvolution) { + if (instr->IsOutputFusion() || + HloPredicateIsOp(instr)) { return instr->shape().IsTuple() ? kHighCost : kMediumCost * ShapeUtil::ElementsIn(instr->shape()); diff --git a/xla/service/loop_schedule_linearizer.cc b/xla/service/loop_schedule_linearizer.cc index 3aa8067f8214b..e04d3b1b86be0 100644 --- a/xla/service/loop_schedule_linearizer.cc +++ b/xla/service/loop_schedule_linearizer.cc @@ -175,7 +175,7 @@ absl::StatusOr LoopScheduleLinearizer::Run( for (HloComputation* computation : module->MakeNonfusionComputations(execution_threads)) { for (HloInstruction* instruction : computation->instructions()) { - if (instruction->opcode() != HloOpcode::kWhile) { + if (HloPredicateIsNotOp(instruction)) { continue; } diff --git a/xla/service/multi_output_fusion.cc b/xla/service/multi_output_fusion.cc index 0967e152717da..c66e10202fef8 100644 --- a/xla/service/multi_output_fusion.cc +++ b/xla/service/multi_output_fusion.cc @@ -100,7 +100,7 @@ absl::StatusOr MultiOutputFusion::Run( continue; } if (instruction_id < user_id && - user->opcode() == HloOpcode::kFusion) { + HloPredicateIsOp(user)) { VLOG(10) << "User ID for user: " << user->name() << " is " << user_id << " which is higher than " << instruction_id; continue; @@ -151,13 +151,13 @@ HloInstruction* MultiOutputFusion::Fuse(HloInstruction* instr1, HloInstruction* fused = instr2; // Make sure that if only one of the instructions is a fusion, or if only one // of the instructions is a multi-output fusion, it's what will be fused into. - if (fused->opcode() == HloOpcode::kFusion) { + if (HloPredicateIsOp(fused)) { std::swap(remaining, fused); } if (fused->IsMultiOutputFusion()) { std::swap(remaining, fused); } - if (fused->opcode() == HloOpcode::kFusion) { + if (HloPredicateIsOp(fused)) { remaining->MergeFusionInstructionIntoMultiOutput(fused); } else { remaining->FuseInstructionIntoMultiOutput(fused); @@ -187,7 +187,7 @@ HloInstruction* MultiOutputFusion::CreateFusion(HloInstruction* base, bool MultiOutputFusion::IsProfitableOperand(HloInstruction* instr) { // kConstant instruction will not have memory reads, so it won't be a profit // source. Skip them. - if (instr->opcode() == HloOpcode::kConstant && + if (HloPredicateIsOp(instr) && ShapeUtil::IsEffectiveScalar(instr->shape())) { return false; } @@ -302,7 +302,7 @@ void MultiOutputFusion::UpdateAfterFuse( bool MultiOutputFusion::LegalToFuse(HloInstruction* instr1, HloInstruction* instr2) { - if (instr1->opcode() != HloOpcode::kFusion) { + if (HloPredicateIsNotOp(instr1)) { return false; } return LegalToFuseMainConstraints(instr1, instr2); @@ -328,7 +328,7 @@ bool MultiOutputFusion::LegalToFuseMainConstraints(HloInstruction* instr1, return false; } for (auto user : instr->users()) { - if (user->opcode() != HloOpcode::kGetTupleElement) { + if (HloPredicateIsNotOp(user)) { return true; } } @@ -424,13 +424,13 @@ bool MultiOutputFusion::Perform() { VLOG(1) << "Fuse!"; VLOG(2) << "Before multi_output_fusion:"; VLOG(2) << "instr1: " << instr1->ToString(); - if (instr1->opcode() == HloOpcode::kFusion) { + if (HloPredicateIsOp(instr1)) { VLOG(2) << "\n" << instr1->fused_instructions_computation()->ToString( HloPrintOptions().set_indent_amount(1)); } VLOG(2) << "instr2: " << instr2->ToString(); - if (instr2->opcode() == HloOpcode::kFusion) { + if (HloPredicateIsOp(instr2)) { VLOG(2) << "\n" << instr2->fused_instructions_computation()->ToString( HloPrintOptions().set_indent_amount(1)); diff --git a/xla/service/p2p_schedule_preparation.cc b/xla/service/p2p_schedule_preparation.cc index 4000b87c33b05..e1a19f01f9458 100644 --- a/xla/service/p2p_schedule_preparation.cc +++ b/xla/service/p2p_schedule_preparation.cc @@ -431,7 +431,7 @@ absl::Status MayAddWhileOpToPipelinedGroup(HloInstruction* while_op, int pipelined_group = 0; // Check whether the while-op init contains a token from a Send result. for (auto hlo : while_op->while_init()->operands()) { - if (hlo->opcode() != HloOpcode::kSendDone) { + if (HloPredicateIsNotOp(hlo)) { continue; } int64_t channel_id = hlo->channel_id().value(); @@ -571,7 +571,7 @@ absl::Status GatherP2PGroupsAndCollectiveInfo( collective_in_computation[computation] = true; } - if (hlo->opcode() == HloOpcode::kWhile) { + if (HloPredicateIsOp(hlo)) { // The pipelined Recv-done/Send-done appears after the while-op. As // such, the pipelined group hasn't been constructed at this point. // Keep the while-op and add to the pipelined group later. @@ -800,7 +800,7 @@ absl::Status LinearizeCollectivesWithOtherP2P( if (!MayInvokeCollectiveOp(hlo, collective_in_computation)) { continue; } - if (hlo->opcode() == HloOpcode::kWhile && + if (HloPredicateIsOp(hlo) && group.kind == P2PGroupKind::kPipelined && group.GetWhileOp() == hlo) { // This is the while-op for chain A. No need to add control dependence. continue; @@ -852,7 +852,7 @@ absl::Status LinearizeCollectivesWithPipelinedP2PChild( if (IsP2POp(hlo) && opcode != HloOpcode::kSendDone) { continue; } - if (hlo->opcode() == HloOpcode::kSendDone) { + if (HloPredicateIsOp(hlo)) { auto group_it = p2p_group_map.find(hlo->channel_id().value()); if (group_it == p2p_group_map.end()) { continue; diff --git a/xla/service/scan_loop_accumulator_input_unification_test.cc b/xla/service/scan_loop_accumulator_input_unification_test.cc index a8a1911663eb1..5c925dc3cfa81 100644 --- a/xla/service/scan_loop_accumulator_input_unification_test.cc +++ b/xla/service/scan_loop_accumulator_input_unification_test.cc @@ -38,7 +38,7 @@ using ScanLoopAccumulatorInputUnificationTest = HloTestBase; HloInstruction* GetTopLevelWhileInstruction(HloModule* module) { for (HloInstruction* instr : module->entry_computation()->MakeInstructionPostOrder()) { - if (instr->opcode() == HloOpcode::kWhile) { + if (HloPredicateIsOp(instr)) { return instr; } } @@ -122,7 +122,7 @@ TEST_F(ScanLoopAccumulatorInputUnificationTest, UnifyAccumulatorInput) { // Index 2 and 3 of the while are replaced with the input arrays. for (HloInstruction* instr : module->entry_computation()->MakeInstructionPostOrder()) { - if (instr->opcode() == HloOpcode::kWhile) { + if (HloPredicateIsOp(instr)) { EXPECT_EQ(instr->while_init()->operand(2)->opcode(), HloOpcode::kConstant); } @@ -224,7 +224,7 @@ TEST_F(ScanLoopAccumulatorInputUnificationTest, UnifyAccumulatorInput2) { // Index 2 and 3 of the while are replaced with the input arrays. for (HloInstruction* instr : module->entry_computation()->MakeInstructionPostOrder()) { - if (instr->opcode() == HloOpcode::kWhile) { + if (HloPredicateIsOp(instr)) { EXPECT_EQ(instr->while_init()->operand(2)->opcode(), HloOpcode::kConstant); EXPECT_EQ(instr->while_init()->operand(3)->opcode(), @@ -466,7 +466,7 @@ TEST_F(ScanLoopAccumulatorInputUnificationTest, MultipleUsersInput) { // Only index 2 is replaced with the array. for (HloInstruction* instr : module->entry_computation()->MakeInstructionPostOrder()) { - if (instr->opcode() == HloOpcode::kWhile) { + if (HloPredicateIsOp(instr)) { EXPECT_EQ(instr->while_init()->operand(2)->opcode(), HloOpcode::kConstant); } diff --git a/xla/service/sharding_remover_test.cc b/xla/service/sharding_remover_test.cc index 110b3bc416867..41b314c4048d8 100644 --- a/xla/service/sharding_remover_test.cc +++ b/xla/service/sharding_remover_test.cc @@ -52,7 +52,7 @@ ENTRY entry { EXPECT_EQ(parameter->user_count(), 2); bool replaced = false; for (HloInstruction* user : parameter->users()) { - if (user->opcode() == HloOpcode::kCopy) { + if (HloPredicateIsOp(user)) { replaced = true; EXPECT_THAT(user, op::Copy(op::Parameter())); break; diff --git a/xla/service/while_loop_concat_code_motion.cc b/xla/service/while_loop_concat_code_motion.cc index e1aa072d30ecc..604139a23ba57 100644 --- a/xla/service/while_loop_concat_code_motion.cc +++ b/xla/service/while_loop_concat_code_motion.cc @@ -279,7 +279,7 @@ std::optional> GetOperandConcatDim( const HloInstruction* hlo, int64_t operand_index, int64_t hlo_concat_dim, bool hlo_inserted_concat_dim, const ConcatGroup* combined_operand_group = nullptr) { - if (hlo->IsElementwise() || hlo->opcode() == HloOpcode::kAllReduce) { + if (hlo->IsElementwise() || HloPredicateIsOp(hlo)) { return std::pair(hlo_concat_dim, hlo_inserted_concat_dim); } int64_t operand_concat_dim = -1; @@ -288,7 +288,7 @@ std::optional> GetOperandConcatDim( combined_operand_group == nullptr ? hlo->operand(operand_index)->shape() : combined_operand_group->elements.back()->shape(); - if (hlo->opcode() == HloOpcode::kBroadcast) { + if (HloPredicateIsOp(hlo)) { operand_concat_dim = 0; operand_inserted_concat_dim = true; // Try to place operand_concat_dim adjacent to dims the same way as the @@ -311,7 +311,7 @@ std::optional> GetOperandConcatDim( min_dist_to_concat_dim = hlo->dimensions(i) - hlo_concat_dim; } } - } else if (hlo->opcode() == HloOpcode::kReduce) { + } else if (HloPredicateIsOp(hlo)) { if (operand_index != 0) { return std::nullopt; } @@ -327,7 +327,7 @@ std::optional> GetOperandConcatDim( operand_concat_dim++; } } - } else if (hlo->opcode() == HloOpcode::kReshape) { + } else if (HloPredicateIsOp(hlo)) { int64_t i = 0; int64_t j = 0; operand_inserted_concat_dim = false; @@ -375,7 +375,7 @@ std::optional> GetOperandConcatDim( void ModifyHloPropertiesForConcatShape(const ConcatGroup& group, HloInstruction* hlo) { *hlo->mutable_shape() = group.GetConcatShape(); - if (hlo->opcode() == HloOpcode::kBroadcast) { + if (HloPredicateIsOp(hlo)) { // Use the last element to infer the operand concat dim, since the first // element's operand might have been rewriten. auto operand_dim = GetOperandConcatDim( @@ -408,7 +408,7 @@ void ModifyHloPropertiesForConcatShape(const ConcatGroup& group, } } *hlo->mutable_dimensions() = std::move(dims); - } else if (hlo->opcode() == HloOpcode::kReduce) { + } else if (HloPredicateIsOp(hlo)) { auto operand_dim = GetOperandConcatDim( group.elements.back(), 0, group.concat_dim, group.inserted_concat_dim); int64_t operand_concat_dim = operand_dim->first; @@ -500,7 +500,7 @@ bool GroupHlosForConcat( continue; } if (absl::c_all_of(hlos, [&](const HloInstruction* element) { - return element->opcode() == HloOpcode::kGetTupleElement && + return HloPredicateIsOp(element) && element->operand(0) == body->parameter_instruction(0); })) { group_is_param_gtes = true; @@ -530,7 +530,7 @@ bool GroupHlosForConcat( /*layout_sensitive=*/false)) { return true; } - if (element->opcode() == HloOpcode::kReduce && + if (HloPredicateIsOp(element) && (element->operand_count() != 2 || element->operand(1) != hlos[0]->operand(1))) { return true; @@ -641,7 +641,7 @@ bool GroupHlosForConcat( std::vector TupleElementsUsedInCond(HloInstruction* loop) { std::vector result(loop->shape().tuple_shapes_size(), false); for (auto user : loop->while_condition()->parameter_instruction(0)->users()) { - if (user->opcode() != HloOpcode::kGetTupleElement) { + if (HloPredicateIsNotOp(user)) { absl::c_fill(result, true); return result; } @@ -696,7 +696,7 @@ absl::Status RemoveCopiesFromRoot(HloComputation* body) { CHECK_EQ(root->opcode(), HloOpcode::kTuple); for (int64_t i = 0; i < root->operand_count(); ++i) { auto copy = root->mutable_operand(i); - if (copy->opcode() == HloOpcode::kCopy) { + if (HloPredicateIsOp(copy)) { TF_RETURN_IF_ERROR(root->ReplaceOperandWith(i, copy->mutable_operand(0))); } } @@ -798,14 +798,14 @@ absl::Status RewriteLoopWithConcatGroups( continue; } const auto& group = groups.GetGroup(group_and_index->first); - if (hlo->opcode() == HloOpcode::kSlice) { + if (HloPredicateIsOp(hlo)) { // We could just replace hlo with its operand; however, to follow the // practice of using the first element as full data, we defer that // replacement. slices_to_remove.push_back(hlo); } else { int64_t operand_count_to_adjust = hlo->operand_count(); - if (hlo->opcode() == HloOpcode::kReduce) { + if (HloPredicateIsOp(hlo)) { CHECK_EQ(operand_count_to_adjust, 2); operand_count_to_adjust = 1; } @@ -915,7 +915,8 @@ absl::Status RewriteLoopWithConcatGroups( continue; } const auto& group_and_index = groups.GetGroupIndex(hlo); - if ((!group_and_index.has_value() || hlo->opcode() == HloOpcode::kReduce) && + if ((!group_and_index.has_value() || + HloPredicateIsOp(hlo)) && hlo != body->root_instruction()) { auto operands = hlo->operands(); if (group_and_index.has_value()) { @@ -949,7 +950,8 @@ absl::StatusOr RunOnLoop(HloInstruction* loop, auto body = loop->while_body(); auto param = body->parameter_instruction(0); auto root = body->root_instruction(); - if (!param->shape().IsTuple() || root->opcode() != HloOpcode::kTuple) { + if (!param->shape().IsTuple() || + HloPredicateIsNotOp(root)) { return false; } std::vector gtes(param->shape().tuple_shapes_size(), @@ -957,7 +959,7 @@ absl::StatusOr RunOnLoop(HloInstruction* loop, ConcatGroups groups; auto indices_used_in_cond = TupleElementsUsedInCond(loop); for (auto user : param->users()) { - if (user->opcode() != HloOpcode::kGetTupleElement) { + if (HloPredicateIsNotOp(user)) { // Unhandled user opcode. return false; } @@ -977,7 +979,7 @@ absl::StatusOr RunOnLoop(HloInstruction* loop, for (int64_t i = 0; i < body_instructions.size(); ++i) { auto hlo = body_instructions[i]; topological_order[hlo] = i; - if (hlo->opcode() == HloOpcode::kConcatenate && + if (HloPredicateIsOp(hlo) && hlo->operand_count() >= min_operand_count_to_optimize) { concats.push_back(hlo); } @@ -1026,7 +1028,7 @@ absl::StatusOr WhileLoopConcatCodeMotion::Run( for (HloComputation* comp : module->MakeComputationPostOrder(execution_threads)) { for (HloInstruction* hlo : comp->MakeInstructionPostOrder()) { - if (hlo->opcode() == HloOpcode::kWhile) { + if (HloPredicateIsOp(hlo)) { TF_ASSIGN_OR_RETURN(bool loop_changed, RunOnLoop(hlo, min_operand_count_to_optimize_)); changed |= loop_changed; diff --git a/xla/service/while_loop_constant_sinking_test.cc b/xla/service/while_loop_constant_sinking_test.cc index 2cfd69a9254e8..b5c3d59783182 100644 --- a/xla/service/while_loop_constant_sinking_test.cc +++ b/xla/service/while_loop_constant_sinking_test.cc @@ -295,7 +295,7 @@ ENTRY entry { op::Tuple(op::GetTupleElement(), op::GetTupleElement(), op::GetTupleElement())); for (const HloInstruction* inst : while_body->instructions()) { - if (inst->opcode() == HloOpcode::kConstant) { + if (HloPredicateIsOp(inst)) { EXPECT_GT(inst->user_count(), 0); } } @@ -422,7 +422,7 @@ ENTRY entry { auto* while_condition = module->GetComputationWithName("condition.sunk"); EXPECT_THAT(while_condition->root_instruction(), op::Lt(_, op::Constant())); for (const HloInstruction* inst : while_condition->instructions()) { - if (inst->opcode() == HloOpcode::kConstant) { + if (HloPredicateIsOp(inst)) { EXPECT_GT(inst->user_count(), 0); } } diff --git a/xla/service/while_loop_invariant_code_motion_test.cc b/xla/service/while_loop_invariant_code_motion_test.cc index eadb19462118f..392336c32b52b 100644 --- a/xla/service/while_loop_invariant_code_motion_test.cc +++ b/xla/service/while_loop_invariant_code_motion_test.cc @@ -49,7 +49,7 @@ static void FindOnlyWhileInstruction(HloComputation* computation, HloInstruction** while_instruction) { *while_instruction = nullptr; for (auto* instr : computation->instructions()) { - if (instr->opcode() == HloOpcode::kWhile) { + if (HloPredicateIsOp(instr)) { ASSERT_EQ(*while_instruction, nullptr); *while_instruction = instr; }