Skip to content

Commit 55d96a7

Browse files
yunxingtensorflower-gardener
authored andcommitted
[XLA] Do not simplify loops with trip count = 1 if there is an infeed in it.
PiperOrigin-RevId: 303217179 Change-Id: Ida39742d25319b878fbc10b675b2133bf2e6d5b4
1 parent 3009664 commit 55d96a7

File tree

3 files changed

+37
-45
lines changed

3 files changed

+37
-45
lines changed

tensorflow/compiler/xla/service/BUILD

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2228,6 +2228,7 @@ cc_library(
22282228
":while_loop_analysis",
22292229
"//tensorflow/compiler/xla:shape_util",
22302230
"//tensorflow/compiler/xla:statusor",
2231+
"@com_google_absl//absl/algorithm:container",
22312232
"@com_google_absl//absl/container:flat_hash_map",
22322233
"@com_google_absl//absl/container:flat_hash_set",
22332234
"@com_google_absl//absl/strings",

tensorflow/compiler/xla/service/while_loop_simplifier.cc

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,15 +14,19 @@ limitations under the License.
1414
==============================================================================*/
1515

1616
#include "tensorflow/compiler/xla/service/while_loop_simplifier.h"
17+
18+
#include "absl/algorithm/container.h"
1719
#include "absl/container/flat_hash_map.h"
1820
#include "absl/container/flat_hash_set.h"
1921
#include "absl/strings/str_cat.h"
2022
#include "absl/strings/str_join.h"
2123
#include "absl/types/optional.h"
2224
#include "tensorflow/compiler/xla/primitive_util.h"
2325
#include "tensorflow/compiler/xla/service/call_inliner.h"
26+
#include "tensorflow/compiler/xla/service/hlo_computation.h"
2427
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
2528
#include "tensorflow/compiler/xla/service/hlo_instructions.h"
29+
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
2630
#include "tensorflow/compiler/xla/service/hlo_query.h"
2731
#include "tensorflow/compiler/xla/service/pattern_matcher.h"
2832
#include "tensorflow/compiler/xla/service/while_loop_analysis.h"
@@ -1010,6 +1014,35 @@ StatusOr<bool> WhileLoopSimplifier::Run(HloModule* module) {
10101014
continue;
10111015
}
10121016

1017+
// Do not simplify the loop away when there is a side-effectful op,
1018+
// otherwise the infeed op may not inherit the data dependency from
1019+
// the while loop.
1020+
//
1021+
// Example: while_body (param_a) {
1022+
// param_a = parameter(0)
1023+
// infeed2 = infeed()
1024+
// }
1025+
//
1026+
// infeed1 = ...
1027+
// while = while(infeed1), body=while_body // infeed2 has implicit
1028+
// dependency on infeed1.
1029+
//
1030+
// After simplification:
1031+
//
1032+
// infeed1 = ...
1033+
// infeed2 = infeed() // no dependency between infeed1 and infeed2. infeed1
1034+
// // can be scheduled after infeed2.
1035+
//
1036+
bool has_side_effects = absl::c_any_of(
1037+
while_op->called_computations(), [](const HloComputation* computation) {
1038+
return computation->HasSideEffect();
1039+
});
1040+
if (has_side_effects) {
1041+
VLOG(2) << "Not attempting to simplify while loop because it contains a "
1042+
"side-effecting node: "
1043+
<< while_op->ToShortString();
1044+
continue;
1045+
}
10131046
TF_ASSIGN_OR_RETURN(bool result, TryPropagateConstant(while_op));
10141047
changed |= result;
10151048

tensorflow/compiler/xla/service/while_loop_simplifier_test.cc

Lines changed: 3 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -209,8 +209,8 @@ TEST_F(WhileLoopSimplifierTest, LoopWithRecvNotSimplified) {
209209
EXPECT_FALSE(WhileLoopSimplifier().Run(m.get()).ValueOrDie());
210210
}
211211

212-
// We can simplify loops whose bodies contain infeed or other side-effecting
213-
// instructions other than send/recv.
212+
// We can't simplify loops whose bodies contain infeed or other side-effecting
213+
// instructions.
214214
TEST_F(WhileLoopSimplifierTest, LoopWithInfeedSimplified) {
215215
auto m = MakeModuleWithSimpleLoop(/*num_iters=*/1);
216216
HloComputation* computation = m->entry_computation();
@@ -220,8 +220,7 @@ TEST_F(WhileLoopSimplifierTest, LoopWithInfeedSimplified) {
220220
auto token = while_body->AddInstruction(HloInstruction::CreateToken());
221221
while_body->AddInstruction(HloInstruction::CreateInfeed(
222222
ShapeUtil::MakeShape(F32, {1}), token, "config"));
223-
EXPECT_TRUE(WhileLoopSimplifier().Run(m.get()).ValueOrDie());
224-
EXPECT_THAT(m->entry_computation()->root_instruction(), op::Tuple());
223+
EXPECT_FALSE(WhileLoopSimplifier().Run(m.get()).ValueOrDie());
225224
}
226225

227226
// We don't simplify trip-count-1 loops whose *conditions* contain infeed or
@@ -445,47 +444,6 @@ TEST_F(WhileLoopSimplifierTest, RemoveUnusedLoopOperands) {
445444
op::GetTupleElement(op::Parameter(0), /*tuple_index=*/1)));
446445
}
447446

448-
// Check that we can remove unused loop operands even if the loop contains a
449-
// side-effecting instruction.
450-
TEST_F(WhileLoopSimplifierTest,
451-
RemoveUnusedLoopOperandsDespiteSideEffectingOps) {
452-
const string hlo_string = R"(
453-
HloModule RemoveUnusedOperands
454-
body {
455-
loop_var = (s32[]) parameter(0)
456-
gte0 = s32[] get-tuple-element(loop_var), index=0
457-
token0 = token[] after-all()
458-
unused = ((s32[], pred[]), token[]) infeed(token0)
459-
ROOT tuple = (s32[]) tuple(gte0)
460-
}
461-
cond {
462-
loop_var = (s32[]) parameter(0)
463-
ROOT constant = pred[] constant(true)
464-
}
465-
ENTRY RemoveUnusedOperands {
466-
x = s32[] parameter(0)
467-
tuple.1 = (s32[]) tuple(s32[] x)
468-
ROOT while = (s32[]) while((s32[]) tuple.1),
469-
condition=cond, body=body
470-
}
471-
)";
472-
473-
auto m = ParseAndReturnVerifiedModule(hlo_string).ValueOrDie();
474-
EXPECT_TRUE(WhileLoopSimplifier().Run(m.get()).ValueOrDie());
475-
476-
// The original while instruction is still left in the module as a dead
477-
// instruction, find a while instruction with a different name as the new
478-
// while instruction.
479-
const auto& instrs = m->entry_computation()->instructions();
480-
HloInstruction* new_while_op =
481-
*absl::c_find_if(instrs, [&](const HloInstruction* instr) {
482-
return (instr->opcode() == HloOpcode::kWhile &&
483-
instr->name() != "while");
484-
});
485-
EXPECT_TRUE(ShapeUtil::IsEmptyTuple(new_while_op->shape()))
486-
<< new_while_op->shape().ToString();
487-
}
488-
489447
TEST_F(WhileLoopSimplifierTest, LoopWithNonTupleBodyShapeNotSimplified) {
490448
const string hlo_string = R"(
491449
HloModule BodyHasNonTupleRoot

0 commit comments

Comments
 (0)