Skip to content

Commit c3a8fb1

Browse files
Change unique_ptr to shared_ptr for latency_estimator and async_tracker in latency_hiding_scheduler so that they can potentially be reused across different users without constructing new ones.
PiperOrigin-RevId: 750398930
1 parent cfc1eaf commit c3a8fb1

File tree

5 files changed

+50
-32
lines changed

5 files changed

+50
-32
lines changed

xla/service/BUILD

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1171,7 +1171,6 @@ cc_library(
11711171
"//xla/hlo/analysis:hlo_alias_analysis",
11721172
"//xla/hlo/analysis:hlo_reachability",
11731173
"//xla/hlo/ir:hlo",
1174-
"//xla/hlo/ir:ptrvec",
11751174
"//xla/hlo/pass:hlo_pass",
11761175
"//xla/tsl/platform:errors",
11771176
"//xla/tsl/platform:statusor",
@@ -5837,6 +5836,7 @@ cc_library(
58375836
"//xla/hlo/ir:ptrvec",
58385837
"//xla/hlo/pass:hlo_pass",
58395838
"//xla/tsl/platform:statusor",
5839+
"@com_google_absl//absl/container:btree",
58405840
"@com_google_absl//absl/container:flat_hash_map",
58415841
"@com_google_absl//absl/container:flat_hash_set",
58425842
"@com_google_absl//absl/log",

xla/service/latency_hiding_scheduler.cc

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2926,15 +2926,16 @@ absl::StatusOr<bool> LatencyHidingScheduler::Run(
29262926
saved_schedules[computation] = std::move(new_schedule);
29272927
}
29282928
}
2929-
LOG(INFO) << "LatencyHidingScheduler current memory usage: "
2929+
LOG(INFO) << "[" << name() << "]"
2930+
<< " LatencyHidingScheduler current memory usage: "
29302931
<< scheduler_core_->GetMemoryPeak()
29312932
<< " bytes. Current limit: " << scheduler_core_->GetMemoryLimit();
29322933
for (HloComputation* computation : computations_to_schedule) {
2933-
VLOG(1) << "Statistics before scheduling:";
2934+
VLOG(1) << "[" << name() << "] Statistics before scheduling:";
29342935
LogScheduleStatistics(computation);
29352936
module->schedule().set_sequence(
29362937
computation, absl::MakeConstSpan(saved_schedules[computation]));
2937-
VLOG(1) << "Statistics after scheduling:";
2938+
VLOG(1) << "[" << name() << "] Statistics after scheduling:";
29382939
LogScheduleStatistics(computation);
29392940
}
29402941
if (debug_options.xla_dump_latency_hiding_schedule()) {

xla/service/latency_hiding_scheduler.h

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,6 @@ limitations under the License.
4343
#include "xla/hlo/ir/hlo_instruction.h"
4444
#include "xla/hlo/ir/hlo_opcode.h"
4545
#include "xla/hlo/ir/hlo_schedule.h"
46-
#include "xla/hlo/ir/ptrvec.h"
4746
#include "xla/hlo/pass/hlo_pass_interface.h"
4847
#include "xla/map_util.h"
4948
#include "xla/service/hlo_buffer.h"
@@ -1224,8 +1223,8 @@ class LatencyHidingScheduler : public HloModulePass {
12241223
};
12251224

12261225
LatencyHidingScheduler(
1227-
std::unique_ptr<LatencyEstimator> latency_estimator,
1228-
std::unique_ptr<AsyncTracker> async_tracker,
1226+
std::shared_ptr<LatencyEstimator> latency_estimator,
1227+
std::shared_ptr<AsyncTracker> async_tracker,
12291228
std::unique_ptr<SchedulerCore> scheduler_core,
12301229
const HloCostAnalysis::ShapeSizeFunction& shape_size_bytes)
12311230
: latency_estimator_(std::move(latency_estimator)),
@@ -1252,8 +1251,8 @@ class LatencyHidingScheduler : public HloModulePass {
12521251
virtual void LogScheduleStatistics(const HloComputation* computation);
12531252

12541253
private:
1255-
std::unique_ptr<LatencyEstimator> latency_estimator_;
1256-
std::unique_ptr<AsyncTracker> async_tracker_;
1254+
std::shared_ptr<LatencyEstimator> latency_estimator_;
1255+
std::shared_ptr<AsyncTracker> async_tracker_;
12571256
std::unique_ptr<SchedulerCore> scheduler_core_;
12581257
const HloCostAnalysis::ShapeSizeFunction shape_size_bytes_;
12591258
absl::flat_hash_set<HloComputation*> computations_to_schedule_;

xla/service/legalize_scheduling_annotations.cc

Lines changed: 29 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ limitations under the License.
2121
#include <string>
2222
#include <vector>
2323

24+
#include "absl/container/btree_map.h"
2425
#include "absl/container/flat_hash_map.h"
2526
#include "absl/container/flat_hash_set.h"
2627
#include "absl/log/check.h"
@@ -48,7 +49,8 @@ namespace {
4849
// Given a group of annotated instructions (sources), find all reachable
4950
// instructions from them in the same computation.
5051
absl::flat_hash_set<HloInstruction*> PropagateAnnotationFromSources(
51-
const std::vector<HloInstruction*>& sources, HloComputation* computation) {
52+
const std::vector<HloInstruction*>& sources,
53+
const HloComputation* computation) {
5254
absl::flat_hash_set<HloInstruction*> to_annotate;
5355
auto reachability = HloReachabilityMap::Build(computation);
5456
// worklist contains instructions that can reach any source instruction.
@@ -107,28 +109,13 @@ absl::Status AttachAnnotation(
107109
absl::StrCat(annotation_id) + " to " + std::string(instr->name()) +
108110
" but it has an existing annotation: " + absl::StrCat(*id));
109111
}
112+
LOG(INFO) << "Propagating annotation " << annotation_id << " to "
113+
<< instr->name();
110114
SetSchedulingAnnotation(instr, annotation_id);
111115
}
112116
return absl::OkStatus();
113117
}
114118

115-
absl::StatusOr<bool> PropagateAnnotations(
116-
HloComputation* computation,
117-
const absl::flat_hash_map<int64_t, std::vector<HloInstruction*>>&
118-
annotation_id_to_instructions) {
119-
bool changed = false;
120-
for (auto& [annotation_id, sources] : annotation_id_to_instructions) {
121-
absl::flat_hash_set<HloInstruction*> to_annotate =
122-
PropagateAnnotationFromSources(sources, computation);
123-
changed |= (!to_annotate.empty());
124-
auto status = AttachAnnotation(annotation_id, to_annotate);
125-
if (!status.ok()) {
126-
return status;
127-
}
128-
}
129-
return changed;
130-
}
131-
132119
absl::StatusOr<int64_t> ExtractAnnotation(
133120
const ::google::protobuf::Map<std::string, std::string>& attrs,
134121
absl::string_view instr_name) {
@@ -261,6 +248,23 @@ absl::Status CheckGapBetweenAnnotatedInstructions(
261248

262249
} // namespace
263250

251+
absl::StatusOr<bool> LegalizeSchedulingAnnotations::PropagateAnnotations(
252+
const HloComputation* computation,
253+
const absl::btree_map<int64_t, std::vector<HloInstruction*>>&
254+
annotation_id_to_instructions) {
255+
bool changed = false;
256+
for (auto& [annotation_id, sources] : annotation_id_to_instructions) {
257+
absl::flat_hash_set<HloInstruction*> to_annotate =
258+
PropagateAnnotationFromSources(sources, computation);
259+
changed |= (!to_annotate.empty());
260+
auto status = AttachAnnotation(annotation_id, to_annotate);
261+
if (!status.ok()) {
262+
return status;
263+
}
264+
}
265+
return changed;
266+
}
267+
264268
bool LegalizeSchedulingAnnotations::KeepSchedulingAnnotation(
265269
HloInstruction* instr) {
266270
const auto& attrs = instr->frontend_attributes().map();
@@ -354,10 +358,12 @@ absl::StatusOr<bool> LegalizeSchedulingAnnotations::Run(
354358
return false;
355359
}
356360

357-
auto status = CheckStartDoneAnnotationConsistency(annotation_to_instructions,
358-
annotation);
359-
if (!status.ok()) {
360-
return status;
361+
if (config_.check_start_done_annotation_consistency) {
362+
auto status = CheckStartDoneAnnotationConsistency(
363+
annotation_to_instructions, annotation);
364+
if (!status.ok()) {
365+
return status;
366+
}
361367
}
362368

363369
bool changed = false;
@@ -368,7 +374,7 @@ absl::StatusOr<bool> LegalizeSchedulingAnnotations::Run(
368374
// same annotation ID.
369375
for (HloComputation* computation :
370376
module->MakeNonfusionComputations(execution_threads)) {
371-
absl::flat_hash_map<int64_t, std::vector<HloInstruction*>>
377+
absl::btree_map<int64_t, std::vector<HloInstruction*>>
372378
per_computation_annotation_to_instructions;
373379
for (const auto& [annotation_id, comp_inst_vector] :
374380
annotation_to_instructions) {

xla/service/legalize_scheduling_annotations.h

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,11 +16,16 @@ limitations under the License.
1616
#ifndef XLA_SERVICE_LEGALIZE_SCHEDULING_ANNOTATIONS_H_
1717
#define XLA_SERVICE_LEGALIZE_SCHEDULING_ANNOTATIONS_H_
1818

19+
#include <cstdint>
1920
#include <utility>
21+
#include <vector>
2022

23+
#include "absl/container/btree_map.h"
24+
#include "absl/container/flat_hash_map.h"
2125
#include "absl/container/flat_hash_set.h"
2226
#include "absl/status/statusor.h"
2327
#include "absl/strings/string_view.h"
28+
#include "xla/hlo/ir/hlo_computation.h"
2429
#include "xla/hlo/ir/hlo_module.h"
2530
#include "xla/hlo/pass/hlo_pass_interface.h"
2631
#include "xla/util.h"
@@ -34,13 +39,20 @@ class LegalizeSchedulingAnnotations : public HloModulePass {
3439
struct Config {
3540
HloPredicate keep_sync_annotation = HloPredicateTrue;
3641
bool propagate_annotation = false;
42+
bool check_start_done_annotation_consistency = true;
3743
};
3844

3945
explicit LegalizeSchedulingAnnotations(Config config)
4046
: config_(std::move(config)) {}
4147
absl::string_view name() const override {
4248
return "legalize-scheduling-annotations";
4349
}
50+
51+
static absl::StatusOr<bool> PropagateAnnotations(
52+
const HloComputation* computation,
53+
const absl::btree_map<int64_t, std::vector<HloInstruction*>>&
54+
annotation_id_to_instructions);
55+
4456
using HloPassInterface::Run;
4557
absl::StatusOr<bool> Run(
4658
HloModule* module,

0 commit comments

Comments
 (0)