@@ -21,6 +21,7 @@ limitations under the License.
2121#include < string>
2222#include < vector>
2323
24+ #include " absl/container/btree_map.h"
2425#include " absl/container/flat_hash_map.h"
2526#include " absl/container/flat_hash_set.h"
2627#include " absl/log/check.h"
@@ -48,7 +49,8 @@ namespace {
4849// Given a group of annotated instructions (sources), find all reachable
4950// instructions from them in the same computation.
5051absl::flat_hash_set<HloInstruction*> PropagateAnnotationFromSources (
51- const std::vector<HloInstruction*>& sources, HloComputation* computation) {
52+ const std::vector<HloInstruction*>& sources,
53+ const HloComputation* computation) {
5254 absl::flat_hash_set<HloInstruction*> to_annotate;
5355 auto reachability = HloReachabilityMap::Build (computation);
5456 // worklist contains instructions that can reach any source instruction.
@@ -107,28 +109,13 @@ absl::Status AttachAnnotation(
107109 absl::StrCat (annotation_id) + " to " + std::string (instr->name ()) +
108110 " but it has an existing annotation: " + absl::StrCat (*id));
109111 }
112+ LOG (INFO) << " Propagating annotation " << annotation_id << " to "
113+ << instr->name ();
110114 SetSchedulingAnnotation (instr, annotation_id);
111115 }
112116 return absl::OkStatus ();
113117}
114118
115- absl::StatusOr<bool > PropagateAnnotations (
116- HloComputation* computation,
117- const absl::flat_hash_map<int64_t , std::vector<HloInstruction*>>&
118- annotation_id_to_instructions) {
119- bool changed = false ;
120- for (auto & [annotation_id, sources] : annotation_id_to_instructions) {
121- absl::flat_hash_set<HloInstruction*> to_annotate =
122- PropagateAnnotationFromSources (sources, computation);
123- changed |= (!to_annotate.empty ());
124- auto status = AttachAnnotation (annotation_id, to_annotate);
125- if (!status.ok ()) {
126- return status;
127- }
128- }
129- return changed;
130- }
131-
132119absl::StatusOr<int64_t > ExtractAnnotation (
133120 const ::google::protobuf::Map<std::string, std::string>& attrs,
134121 absl::string_view instr_name) {
@@ -261,6 +248,23 @@ absl::Status CheckGapBetweenAnnotatedInstructions(
261248
262249} // namespace
263250
251+ absl::StatusOr<bool > LegalizeSchedulingAnnotations::PropagateAnnotations (
252+ const HloComputation* computation,
253+ const absl::btree_map<int64_t , std::vector<HloInstruction*>>&
254+ annotation_id_to_instructions) {
255+ bool changed = false ;
256+ for (auto & [annotation_id, sources] : annotation_id_to_instructions) {
257+ absl::flat_hash_set<HloInstruction*> to_annotate =
258+ PropagateAnnotationFromSources (sources, computation);
259+ changed |= (!to_annotate.empty ());
260+ auto status = AttachAnnotation (annotation_id, to_annotate);
261+ if (!status.ok ()) {
262+ return status;
263+ }
264+ }
265+ return changed;
266+ }
267+
264268bool LegalizeSchedulingAnnotations::KeepSchedulingAnnotation (
265269 HloInstruction* instr) {
266270 return IsSupportedAsyncOp (instr) || config_.keep_sync_annotation (instr);
@@ -348,10 +352,12 @@ absl::StatusOr<bool> LegalizeSchedulingAnnotations::Run(
348352 return false ;
349353 }
350354
351- auto status = CheckStartDoneAnnotationConsistency (annotation_to_instructions,
352- annotation);
353- if (!status.ok ()) {
354- return status;
355+ if (config_.check_start_done_annotation_consistency ) {
356+ auto status = CheckStartDoneAnnotationConsistency (
357+ annotation_to_instructions, annotation);
358+ if (!status.ok ()) {
359+ return status;
360+ }
355361 }
356362
357363 bool changed = false ;
@@ -362,7 +368,7 @@ absl::StatusOr<bool> LegalizeSchedulingAnnotations::Run(
362368 // same annotation ID.
363369 for (HloComputation* computation :
364370 module ->MakeNonfusionComputations (execution_threads)) {
365- absl::flat_hash_map <int64_t , std::vector<HloInstruction*>>
371+ absl::btree_map <int64_t , std::vector<HloInstruction*>>
366372 per_computation_annotation_to_instructions;
367373 for (const auto & [annotation_id, comp_inst_vector] :
368374 annotation_to_instructions) {
0 commit comments