@@ -92,7 +92,9 @@ ENTRY main {
9292 param_0 = f32[] parameter(0)
9393 param_2 = s32[] constant(0)
9494 tuple = (f32[], s32[]) tuple(param_0, param_2)
95- ROOT while = (f32[], s32[]) while(tuple), condition=condition, body=body, backend_config={"known_trip_count":{"n":"10"}}
95+ ROOT while = (f32[], s32[]) while(tuple), condition=condition, body=body,
96+ backend_config={"known_trip_count":{"n":"10"},
97+ "known_induction_variable":{"tuple_index":"1"}}
9698})" ;
9799
98100 TF_ASSERT_OK_AND_ASSIGN (std::unique_ptr<xla::HloModule> module ,
@@ -110,6 +112,7 @@ ENTRY main {
110112 WhileLoopBackendConfig config,
111113 while_instruction->backend_config <WhileLoopBackendConfig>());
112114 EXPECT_EQ (config.known_trip_count ().n (), 5 );
115+ EXPECT_EQ (config.known_induction_variable ().tuple_index (), 1 );
113116 EXPECT_EQ (CountInstructions ((*while_instruction->while_body ()),
114117 HloOpcode::kAllReduceStart ),
115118 2 );
@@ -1397,6 +1400,94 @@ ENTRY main {
13971400 EXPECT_THAT (double_buffer.Run (module .get ()), IsOkAndHolds (false ));
13981401}
13991402
1403+ TEST_F (GpuLoopDoubleBufferTransformerTest, UpdateInitStepOddTripCount) {
1404+ absl::string_view kModuleString = R"(
1405+ HloModule m
1406+ condition {
1407+ input_tuple = (s32[]) parameter(0)
1408+ iter = s32[] get-tuple-element(input_tuple), index=0
1409+ c12 = s32[] constant(12)
1410+ ROOT continue = pred[] compare(iter, c12), direction=LT
1411+ }
1412+
1413+ body {
1414+ input_tuple = (s32[]) parameter(0)
1415+ iter = s32[] get-tuple-element(input_tuple), index=0
1416+ c2 = s32[] constant(2)
1417+ next_iter = s32[] add(iter, c2)
1418+ ROOT output_tuple = (s32[]) tuple(next_iter)
1419+ }
1420+
1421+ ENTRY main {
1422+ c3 = s32[] constant(3)
1423+ tuple = (s32[]) tuple(c3)
1424+ // Values: 3, 5, 7, 9, 11
1425+ ROOT while = (s32[]) while(tuple), condition=condition, body=body,
1426+ backend_config={"known_trip_count":{"n":"5"},
1427+ "known_init_step":{"init":"3","step":"2"}}
1428+ })" ;
1429+
1430+ TF_ASSERT_OK_AND_ASSIGN (std::unique_ptr<xla::HloModule> module ,
1431+ ParseAndReturnVerifiedModule (kModuleString ));
1432+ DoubleBufferLoopUnrolling unroller (
1433+ DoubleBufferLoopUnrolling::UnrollStrategy::kDoubleBuffer );
1434+ TF_ASSERT_OK_AND_ASSIGN (bool changed, unroller.Run (module .get ()));
1435+ EXPECT_TRUE (changed);
1436+
1437+ HloInstruction* while_instruction = hlo_query::GetFirstInstructionWithOpcode (
1438+ *module ->entry_computation (), HloOpcode::kWhile );
1439+ TF_ASSERT_OK_AND_ASSIGN (
1440+ WhileLoopBackendConfig config,
1441+ while_instruction->backend_config <WhileLoopBackendConfig>());
1442+ EXPECT_EQ (config.known_trip_count ().n (), 2 );
1443+ EXPECT_EQ (config.known_init_step ().init (), 5 );
1444+ EXPECT_EQ (config.known_init_step ().step (), 4 );
1445+ }
1446+
1447+ TEST_F (GpuLoopDoubleBufferTransformerTest, UpdateInitStepEvenTripCount) {
1448+ absl::string_view kModuleString = R"(
1449+ HloModule m
1450+ condition {
1451+ input_tuple = (s32[]) parameter(0)
1452+ iter = s32[] get-tuple-element(input_tuple), index=0
1453+ c14 = s32[] constant(14)
1454+ ROOT continue = pred[] compare(iter, c14), direction=LT
1455+ }
1456+
1457+ body {
1458+ input_tuple = (s32[]) parameter(0)
1459+ iter = s32[] get-tuple-element(input_tuple), index=0
1460+ c2 = s32[] constant(2)
1461+ next_iter = s32[] add(iter, c2)
1462+ ROOT output_tuple = (s32[]) tuple(next_iter)
1463+ }
1464+
1465+ ENTRY main {
1466+ c3 = s32[] constant(3)
1467+ tuple = (s32[]) tuple(c3)
1468+ // Values: 3, 5, 7, 9, 11, 13
1469+ ROOT while = (s32[]) while(tuple), condition=condition, body=body,
1470+ backend_config={"known_trip_count":{"n":"6"},
1471+ "known_init_step":{"init":"3","step":"2"}}
1472+ })" ;
1473+
1474+ TF_ASSERT_OK_AND_ASSIGN (std::unique_ptr<xla::HloModule> module ,
1475+ ParseAndReturnVerifiedModule (kModuleString ));
1476+ DoubleBufferLoopUnrolling unroller (
1477+ DoubleBufferLoopUnrolling::UnrollStrategy::kDoubleBuffer );
1478+ TF_ASSERT_OK_AND_ASSIGN (bool changed, unroller.Run (module .get ()));
1479+ EXPECT_TRUE (changed);
1480+
1481+ HloInstruction* while_instruction = hlo_query::GetFirstInstructionWithOpcode (
1482+ *module ->entry_computation (), HloOpcode::kWhile );
1483+ TF_ASSERT_OK_AND_ASSIGN (
1484+ WhileLoopBackendConfig config,
1485+ while_instruction->backend_config <WhileLoopBackendConfig>());
1486+ EXPECT_EQ (config.known_trip_count ().n (), 3 );
1487+ EXPECT_EQ (config.known_init_step ().init (), 3 );
1488+ EXPECT_EQ (config.known_init_step ().step (), 4 );
1489+ }
1490+
14001491} // namespace
14011492} // namespace gpu
14021493} // namespace xla
0 commit comments