Skip to content

Commit

Permalink
[Cleanup] Use HloPredicateIs(Not)Op
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 707921675
  • Loading branch information
frgossen authored and Google-ML-Automation committed Dec 19, 2024
1 parent 88045f6 commit 47c5b30
Show file tree
Hide file tree
Showing 5 changed files with 12 additions and 14 deletions.
2 changes: 1 addition & 1 deletion xla/service/gpu/transforms/copy_fusion.cc
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ absl::StatusOr<bool> CopyFusion::DoCopyFusion(HloComputation* computation) {
}
HloInstruction* root = fused_computation->root_instruction();
if (IsReductionFromOrToContiguousDimensions(*root, device_description_) ||
root->opcode() == HloOpcode::kScatter ||
HloPredicateIsOp<HloOpcode::kScatter>(root) ||
(hlo->IsMultiOutputFusion() &&
absl::c_all_of(root->operands(),
HloPredicateIsOp<HloOpcode::kSlice>))) {
Expand Down
12 changes: 6 additions & 6 deletions xla/service/gpu/transforms/double_buffer_loop_unrolling.cc
Original file line number Diff line number Diff line change
Expand Up @@ -118,9 +118,9 @@ absl::Status SetSendRecvValidationForPeeledInstr(HloInstruction* new_instr,
TF_RET_CHECK(
new_instr->opcode() == old_instr->opcode() &&
"cloned instruction and original instruction have different opcodes");
if (!HloPredicateIsOp<HloOpcode::kCollectivePermute,
HloOpcode::kCollectivePermuteStart, HloOpcode::kSend,
HloOpcode::kRecv>(old_instr)) {
if (HloPredicateIsNotOp<HloOpcode::kCollectivePermute,
HloOpcode::kCollectivePermuteStart, HloOpcode::kSend,
HloOpcode::kRecv>(old_instr)) {
return absl::OkStatus();
}

Expand Down Expand Up @@ -188,9 +188,9 @@ absl::Status SetSendRecvValidation(HloInstruction* cp1, HloInstruction* cp2,
TF_RET_CHECK(
cp2->opcode() == cp1->opcode() &&
"cloned instruction and original instruction have different opcodes");
if (!HloPredicateIsOp<HloOpcode::kCollectivePermute,
HloOpcode::kCollectivePermuteStart, HloOpcode::kSend,
HloOpcode::kRecv>(cp1)) {
if (HloPredicateIsNotOp<HloOpcode::kCollectivePermute,
HloOpcode::kCollectivePermuteStart, HloOpcode::kSend,
HloOpcode::kRecv>(cp1)) {
return absl::OkStatus();
}
const auto& attribute_map = cp2->frontend_attributes().map();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ namespace {
using ::tsl::testing::IsOkAndHolds;

bool HasTritonBlockLevelFusionConfig(const HloInstruction* fusion) {
return fusion->opcode() == HloOpcode::kFusion &&
return HloPredicateIsOp<HloOpcode::kFusion>(fusion) &&
fusion->has_backend_config() &&
fusion->backend_config<GpuBackendConfig>().ok() &&
fusion->backend_config<GpuBackendConfig>()
Expand Down
2 changes: 1 addition & 1 deletion xla/service/gpu/transforms/rename_fusions.cc
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ absl::StatusOr<bool> RenameFusions::Run(
const absl::flat_hash_set<absl::string_view>& execution_threads) {
for (HloComputation* computation : module->MakeNonfusionComputations()) {
for (HloInstruction* instruction : computation->instructions()) {
if (instruction->opcode() != HloOpcode::kFusion ||
if (HloPredicateIsNotOp<HloOpcode::kFusion>(instruction) ||
instruction->fusion_kind() == HloInstruction::FusionKind::kCustom) {
continue;
}
Expand Down
8 changes: 3 additions & 5 deletions xla/service/gpu/transforms/windowed_einsum_handler.cc
Original file line number Diff line number Diff line change
Expand Up @@ -86,11 +86,9 @@ absl::StatusOr<bool> ShiftDequantizationF8(HloComputation* while_body) {
HloInstruction* operand = param_tuple->mutable_operand(k);
// Capture bitcast, broadcast, copy, reshape and transpose ops between
// dequantization and the loop.
while (operand->opcode() == HloOpcode::kBitcast ||
operand->opcode() == HloOpcode::kBroadcast ||
operand->opcode() == HloOpcode::kCopy ||
operand->opcode() == HloOpcode::kReshape ||
operand->opcode() == HloOpcode::kTranspose) {
while (HloPredicateIsOp<HloOpcode::kBitcast, HloOpcode::kBroadcast,
HloOpcode::kCopy, HloOpcode::kReshape,
HloOpcode::kTranspose>(operand)) {
unaries[k].push_back(operand);
operand = operand->mutable_operand(0);
}
Expand Down

0 comments on commit 47c5b30

Please sign in to comment.