Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 15 additions & 0 deletions xla/service/gpu/transforms/double_buffer_loop_unrolling.cc
Original file line number Diff line number Diff line change
Expand Up @@ -519,6 +519,21 @@ absl::StatusOr<bool> DoubleBufferingUnroll(HloInstruction* while_instr,

WhileLoopBackendConfig new_config;
new_config.mutable_known_trip_count()->set_n(exact_trip_count / 2);

// Keep known induction variable metadata if it was present before.
if (config.has_known_induction_variable()) {
*new_config.mutable_known_induction_variable() =
config.known_induction_variable();
}

// Update the init/step metadata if it was present before.
if (config.has_known_init_step()) {
int64_t step = config.known_init_step().step();
new_config.mutable_known_init_step()->set_step(step * 2);
new_config.mutable_known_init_step()->set_init(
config.known_init_step().init() + (peel_one_iteration ? step : 0));
}

TF_RETURN_IF_ERROR(while_instr->set_backend_config(new_config));
return true; // changed
}
Expand Down
93 changes: 92 additions & 1 deletion xla/service/gpu/transforms/double_buffer_loop_unrolling_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,9 @@ ENTRY main {
param_0 = f32[] parameter(0)
param_2 = s32[] constant(0)
tuple = (f32[], s32[]) tuple(param_0, param_2)
ROOT while = (f32[], s32[]) while(tuple), condition=condition, body=body, backend_config={"known_trip_count":{"n":"10"}}
ROOT while = (f32[], s32[]) while(tuple), condition=condition, body=body,
backend_config={"known_trip_count":{"n":"10"},
"known_induction_variable":{"tuple_index":"1"}}
})";

TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<xla::HloModule> module,
Expand All @@ -110,6 +112,7 @@ ENTRY main {
WhileLoopBackendConfig config,
while_instruction->backend_config<WhileLoopBackendConfig>());
EXPECT_EQ(config.known_trip_count().n(), 5);
EXPECT_EQ(config.known_induction_variable().tuple_index(), 1);
EXPECT_EQ(CountInstructions((*while_instruction->while_body()),
HloOpcode::kAllReduceStart),
2);
Expand Down Expand Up @@ -1397,6 +1400,94 @@ ENTRY main {
EXPECT_THAT(double_buffer.Run(module.get()), IsOkAndHolds(false));
}

// Runs double-buffer unrolling on a while loop annotated with a known trip
// count of 5 and known init/step of 3/2, then checks the rewritten backend
// config: trip count 2, step doubled to 4, and init advanced by one original
// step to 5 (consistent with one iteration having been peeled off).
TEST_F(GpuLoopDoubleBufferTransformerTest, UpdateInitStepOddTripCount) {
  absl::string_view kHloText = R"(
HloModule m
condition {
input_tuple = (s32[]) parameter(0)
iter = s32[] get-tuple-element(input_tuple), index=0
c12 = s32[] constant(12)
ROOT continue = pred[] compare(iter, c12), direction=LT
}

body {
input_tuple = (s32[]) parameter(0)
iter = s32[] get-tuple-element(input_tuple), index=0
c2 = s32[] constant(2)
next_iter = s32[] add(iter, c2)
ROOT output_tuple = (s32[]) tuple(next_iter)
}

ENTRY main {
c3 = s32[] constant(3)
tuple = (s32[]) tuple(c3)
// Values: 3, 5, 7, 9, 11
ROOT while = (s32[]) while(tuple), condition=condition, body=body,
backend_config={"known_trip_count":{"n":"5"},
"known_init_step":{"init":"3","step":"2"}}
})";

  // Parse the module and apply the pass with the double-buffer strategy.
  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<xla::HloModule> hlo_module,
                          ParseAndReturnVerifiedModule(kHloText));
  DoubleBufferLoopUnrolling double_buffer_pass(
      DoubleBufferLoopUnrolling::UnrollStrategy::kDoubleBuffer);
  TF_ASSERT_OK_AND_ASSIGN(bool module_changed,
                          double_buffer_pass.Run(hlo_module.get()));
  EXPECT_TRUE(module_changed);

  // Read back the backend config of the first while in the entry computation.
  HloInstruction* loop = hlo_query::GetFirstInstructionWithOpcode(
      *hlo_module->entry_computation(), HloOpcode::kWhile);
  TF_ASSERT_OK_AND_ASSIGN(WhileLoopBackendConfig loop_config,
                          loop->backend_config<WhileLoopBackendConfig>());

  const auto& init_step = loop_config.known_init_step();
  EXPECT_EQ(loop_config.known_trip_count().n(), 2);
  EXPECT_EQ(init_step.init(), 5);
  EXPECT_EQ(init_step.step(), 4);
}

// Runs double-buffer unrolling on a while loop annotated with a known trip
// count of 6 and known init/step of 3/2, then checks the rewritten backend
// config: trip count halved to 3, step doubled to 4, and init left at 3.
TEST_F(GpuLoopDoubleBufferTransformerTest, UpdateInitStepEvenTripCount) {
  absl::string_view kHloText = R"(
HloModule m
condition {
input_tuple = (s32[]) parameter(0)
iter = s32[] get-tuple-element(input_tuple), index=0
c14 = s32[] constant(14)
ROOT continue = pred[] compare(iter, c14), direction=LT
}

body {
input_tuple = (s32[]) parameter(0)
iter = s32[] get-tuple-element(input_tuple), index=0
c2 = s32[] constant(2)
next_iter = s32[] add(iter, c2)
ROOT output_tuple = (s32[]) tuple(next_iter)
}

ENTRY main {
c3 = s32[] constant(3)
tuple = (s32[]) tuple(c3)
// Values: 3, 5, 7, 9, 11, 13
ROOT while = (s32[]) while(tuple), condition=condition, body=body,
backend_config={"known_trip_count":{"n":"6"},
"known_init_step":{"init":"3","step":"2"}}
})";

  // Parse the module and apply the pass with the double-buffer strategy.
  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<xla::HloModule> hlo_module,
                          ParseAndReturnVerifiedModule(kHloText));
  DoubleBufferLoopUnrolling double_buffer_pass(
      DoubleBufferLoopUnrolling::UnrollStrategy::kDoubleBuffer);
  TF_ASSERT_OK_AND_ASSIGN(bool module_changed,
                          double_buffer_pass.Run(hlo_module.get()));
  EXPECT_TRUE(module_changed);

  // Read back the backend config of the first while in the entry computation.
  HloInstruction* loop = hlo_query::GetFirstInstructionWithOpcode(
      *hlo_module->entry_computation(), HloOpcode::kWhile);
  TF_ASSERT_OK_AND_ASSIGN(WhileLoopBackendConfig loop_config,
                          loop->backend_config<WhileLoopBackendConfig>());

  const auto& init_step = loop_config.known_init_step();
  EXPECT_EQ(loop_config.known_trip_count().n(), 3);
  EXPECT_EQ(init_step.init(), 3);
  EXPECT_EQ(init_step.step(), 4);
}

} // namespace
} // namespace gpu
} // namespace xla