Skip to content

Commit 1b15816

Browse files
jreiffers authored and tensorflower-gardener committed
PR #25275: Propagate loop metadata in double_buffer_loop_unrolling.
Imported from GitHub PR openxla/xla#25275

Currently, this pass drops the `known_induction_variable` and `known_init_step` fields in `WhileLoopBackendConfig`. This change keeps them if they were present before and updates them with the updated value, just like `known_trip_count`. This is only done for double buffered loops, not for fully unrolled ones.

Copybara import of the project:

--
6bd44597c2e1dba443a367edf4a21024f6248cb6 by Johannes Reifferscheid <jreiffers@nvidia.com>:

Keep loop metadata after loop unrolling.

--
369ad9bedd031d1e1fd6abacb7a5002f290c72f2 by Johannes Reifferscheid <jreiffers@nvidia.com>:

Fix init update and add tests.

Merging this change closes #25275

FUTURE_COPYBARA_INTEGRATE_REVIEW=openxla/xla#25275 from jreiffers:keep-metadata 369ad9bedd031d1e1fd6abacb7a5002f290c72f2
PiperOrigin-RevId: 750112934
1 parent b0f1414 commit 1b15816

File tree

3 files changed

+108
-2
lines changed

3 files changed

+108
-2
lines changed

tensorflow/tools/pip_package/setup.py.tpl

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -91,7 +91,7 @@ REQUIRED_PACKAGES = [
9191
'libclang >= 13.0.0',
9292
'opt_einsum >= 2.3.2',
9393
'packaging',
94-
'protobuf>=4.21.6,<6.0.0dev',
94+
'protobuf>=4.21.6,<7.0.0dev',
9595
'requests >= 2.21.0, < 3',
9696
'setuptools',
9797
'six >= 1.12.0',

third_party/xla/xla/service/gpu/transforms/double_buffer_loop_unrolling.cc

Lines changed: 15 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -519,6 +519,21 @@ absl::StatusOr<bool> DoubleBufferingUnroll(HloInstruction* while_instr,
519519

520520
WhileLoopBackendConfig new_config;
521521
new_config.mutable_known_trip_count()->set_n(exact_trip_count / 2);
522+
523+
// Keep known induction variable metadata if it was present before.
524+
if (config.has_known_induction_variable()) {
525+
*new_config.mutable_known_induction_variable() =
526+
config.known_induction_variable();
527+
}
528+
529+
// Update the init/step metadata if it was present before.
530+
if (config.has_known_init_step()) {
531+
int64_t step = config.known_init_step().step();
532+
new_config.mutable_known_init_step()->set_step(step * 2);
533+
new_config.mutable_known_init_step()->set_init(
534+
config.known_init_step().init() + (peel_one_iteration ? step : 0));
535+
}
536+
522537
TF_RETURN_IF_ERROR(while_instr->set_backend_config(new_config));
523538
return true; // changed
524539
}

third_party/xla/xla/service/gpu/transforms/double_buffer_loop_unrolling_test.cc

Lines changed: 92 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -92,7 +92,9 @@ ENTRY main {
9292
param_0 = f32[] parameter(0)
9393
param_2 = s32[] constant(0)
9494
tuple = (f32[], s32[]) tuple(param_0, param_2)
95-
ROOT while = (f32[], s32[]) while(tuple), condition=condition, body=body, backend_config={"known_trip_count":{"n":"10"}}
95+
ROOT while = (f32[], s32[]) while(tuple), condition=condition, body=body,
96+
backend_config={"known_trip_count":{"n":"10"},
97+
"known_induction_variable":{"tuple_index":"1"}}
9698
})";
9799

98100
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<xla::HloModule> module,
@@ -110,6 +112,7 @@ ENTRY main {
110112
WhileLoopBackendConfig config,
111113
while_instruction->backend_config<WhileLoopBackendConfig>());
112114
EXPECT_EQ(config.known_trip_count().n(), 5);
115+
EXPECT_EQ(config.known_induction_variable().tuple_index(), 1);
113116
EXPECT_EQ(CountInstructions((*while_instruction->while_body()),
114117
HloOpcode::kAllReduceStart),
115118
2);
@@ -1397,6 +1400,94 @@ ENTRY main {
13971400
EXPECT_THAT(double_buffer.Run(module.get()), IsOkAndHolds(false));
13981401
}
13991402

1403+
TEST_F(GpuLoopDoubleBufferTransformerTest, UpdateInitStepOddTripCount) {
1404+
absl::string_view kModuleString = R"(
1405+
HloModule m
1406+
condition {
1407+
input_tuple = (s32[]) parameter(0)
1408+
iter = s32[] get-tuple-element(input_tuple), index=0
1409+
c12 = s32[] constant(12)
1410+
ROOT continue = pred[] compare(iter, c12), direction=LT
1411+
}
1412+
1413+
body {
1414+
input_tuple = (s32[]) parameter(0)
1415+
iter = s32[] get-tuple-element(input_tuple), index=0
1416+
c2 = s32[] constant(2)
1417+
next_iter = s32[] add(iter, c2)
1418+
ROOT output_tuple = (s32[]) tuple(next_iter)
1419+
}
1420+
1421+
ENTRY main {
1422+
c3 = s32[] constant(3)
1423+
tuple = (s32[]) tuple(c3)
1424+
// Values: 3, 5, 7, 9, 11
1425+
ROOT while = (s32[]) while(tuple), condition=condition, body=body,
1426+
backend_config={"known_trip_count":{"n":"5"},
1427+
"known_init_step":{"init":"3","step":"2"}}
1428+
})";
1429+
1430+
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<xla::HloModule> module,
1431+
ParseAndReturnVerifiedModule(kModuleString));
1432+
DoubleBufferLoopUnrolling unroller(
1433+
DoubleBufferLoopUnrolling::UnrollStrategy::kDoubleBuffer);
1434+
TF_ASSERT_OK_AND_ASSIGN(bool changed, unroller.Run(module.get()));
1435+
EXPECT_TRUE(changed);
1436+
1437+
HloInstruction* while_instruction = hlo_query::GetFirstInstructionWithOpcode(
1438+
*module->entry_computation(), HloOpcode::kWhile);
1439+
TF_ASSERT_OK_AND_ASSIGN(
1440+
WhileLoopBackendConfig config,
1441+
while_instruction->backend_config<WhileLoopBackendConfig>());
1442+
EXPECT_EQ(config.known_trip_count().n(), 2);
1443+
EXPECT_EQ(config.known_init_step().init(), 5);
1444+
EXPECT_EQ(config.known_init_step().step(), 4);
1445+
}
1446+
1447+
TEST_F(GpuLoopDoubleBufferTransformerTest, UpdateInitStepEvenTripCount) {
1448+
absl::string_view kModuleString = R"(
1449+
HloModule m
1450+
condition {
1451+
input_tuple = (s32[]) parameter(0)
1452+
iter = s32[] get-tuple-element(input_tuple), index=0
1453+
c14 = s32[] constant(14)
1454+
ROOT continue = pred[] compare(iter, c14), direction=LT
1455+
}
1456+
1457+
body {
1458+
input_tuple = (s32[]) parameter(0)
1459+
iter = s32[] get-tuple-element(input_tuple), index=0
1460+
c2 = s32[] constant(2)
1461+
next_iter = s32[] add(iter, c2)
1462+
ROOT output_tuple = (s32[]) tuple(next_iter)
1463+
}
1464+
1465+
ENTRY main {
1466+
c3 = s32[] constant(3)
1467+
tuple = (s32[]) tuple(c3)
1468+
// Values: 3, 5, 7, 9, 11, 13
1469+
ROOT while = (s32[]) while(tuple), condition=condition, body=body,
1470+
backend_config={"known_trip_count":{"n":"6"},
1471+
"known_init_step":{"init":"3","step":"2"}}
1472+
})";
1473+
1474+
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<xla::HloModule> module,
1475+
ParseAndReturnVerifiedModule(kModuleString));
1476+
DoubleBufferLoopUnrolling unroller(
1477+
DoubleBufferLoopUnrolling::UnrollStrategy::kDoubleBuffer);
1478+
TF_ASSERT_OK_AND_ASSIGN(bool changed, unroller.Run(module.get()));
1479+
EXPECT_TRUE(changed);
1480+
1481+
HloInstruction* while_instruction = hlo_query::GetFirstInstructionWithOpcode(
1482+
*module->entry_computation(), HloOpcode::kWhile);
1483+
TF_ASSERT_OK_AND_ASSIGN(
1484+
WhileLoopBackendConfig config,
1485+
while_instruction->backend_config<WhileLoopBackendConfig>());
1486+
EXPECT_EQ(config.known_trip_count().n(), 3);
1487+
EXPECT_EQ(config.known_init_step().init(), 3);
1488+
EXPECT_EQ(config.known_init_step().step(), 4);
1489+
}
1490+
14001491
} // namespace
14011492
} // namespace gpu
14021493
} // namespace xla

0 commit comments

Comments
 (0)