diff --git a/xla/service/gpu/transforms/copy_fusion.cc b/xla/service/gpu/transforms/copy_fusion.cc index 1b34fb13a7290..23706a4dbcf14 100644 --- a/xla/service/gpu/transforms/copy_fusion.cc +++ b/xla/service/gpu/transforms/copy_fusion.cc @@ -75,7 +75,7 @@ absl::StatusOr CopyFusion::DoCopyFusion(HloComputation* computation) { } HloInstruction* root = fused_computation->root_instruction(); if (IsReductionFromOrToContiguousDimensions(*root, device_description_) || - root->opcode() == HloOpcode::kScatter || + HloPredicateIsOp(root) || (hlo->IsMultiOutputFusion() && absl::c_all_of(root->operands(), HloPredicateIsOp))) { diff --git a/xla/service/gpu/transforms/double_buffer_loop_unrolling.cc b/xla/service/gpu/transforms/double_buffer_loop_unrolling.cc index 7d217aac5674e..c46c3f53f6a84 100644 --- a/xla/service/gpu/transforms/double_buffer_loop_unrolling.cc +++ b/xla/service/gpu/transforms/double_buffer_loop_unrolling.cc @@ -118,9 +118,9 @@ absl::Status SetSendRecvValidationForPeeledInstr(HloInstruction* new_instr, TF_RET_CHECK( new_instr->opcode() == old_instr->opcode() && "cloned instruction and original instruction have different opcodes"); - if (!HloPredicateIsOp(old_instr)) { + if (HloPredicateIsNotOp(old_instr)) { return absl::OkStatus(); } @@ -188,9 +188,9 @@ absl::Status SetSendRecvValidation(HloInstruction* cp1, HloInstruction* cp2, TF_RET_CHECK( cp2->opcode() == cp1->opcode() && "cloned instruction and original instruction have different opcodes"); - if (!HloPredicateIsOp(cp1)) { + if (HloPredicateIsNotOp(cp1)) { return absl::OkStatus(); } const auto& attribute_map = cp2->frontend_attributes().map(); diff --git a/xla/service/gpu/transforms/fusion_block_level_rewriter_test.cc b/xla/service/gpu/transforms/fusion_block_level_rewriter_test.cc index d574fc106282a..d78dc65be9772 100644 --- a/xla/service/gpu/transforms/fusion_block_level_rewriter_test.cc +++ b/xla/service/gpu/transforms/fusion_block_level_rewriter_test.cc @@ -47,7 +47,7 @@ namespace { using ::tsl::testing::IsOkAndHolds; bool HasTritonBlockLevelFusionConfig(const HloInstruction* fusion) { - return fusion->opcode() == HloOpcode::kFusion && + return HloPredicateIsOp(fusion) && fusion->has_backend_config() && fusion->backend_config().ok() && fusion->backend_config() diff --git a/xla/service/gpu/transforms/rename_fusions.cc b/xla/service/gpu/transforms/rename_fusions.cc index 29f3edf968fb3..ac396b3fd5915 100644 --- a/xla/service/gpu/transforms/rename_fusions.cc +++ b/xla/service/gpu/transforms/rename_fusions.cc @@ -78,7 +78,7 @@ absl::StatusOr RenameFusions::Run( const absl::flat_hash_set& execution_threads) { for (HloComputation* computation : module->MakeNonfusionComputations()) { for (HloInstruction* instruction : computation->instructions()) { - if (instruction->opcode() != HloOpcode::kFusion || + if (HloPredicateIsNotOp(instruction) || instruction->fusion_kind() == HloInstruction::FusionKind::kCustom) { continue; } diff --git a/xla/service/gpu/transforms/windowed_einsum_handler.cc b/xla/service/gpu/transforms/windowed_einsum_handler.cc index 2ffec420c30ae..db84a666394f4 100644 --- a/xla/service/gpu/transforms/windowed_einsum_handler.cc +++ b/xla/service/gpu/transforms/windowed_einsum_handler.cc @@ -86,11 +86,9 @@ absl::StatusOr ShiftDequantizationF8(HloComputation* while_body) { HloInstruction* operand = param_tuple->mutable_operand(k); // Capture bitcast, broadcast, copy, reshape and transpose ops between // dequantization and the loop. - while (operand->opcode() == HloOpcode::kBitcast || - operand->opcode() == HloOpcode::kBroadcast || - operand->opcode() == HloOpcode::kCopy || - operand->opcode() == HloOpcode::kReshape || - operand->opcode() == HloOpcode::kTranspose) { + while (HloPredicateIsOp(operand)) { unaries[k].push_back(operand); operand = operand->mutable_operand(0); }