Skip to content

Commit 6b5a7a1

Browse files
Change unique_ptr to shared_ptr for latency_estimator and async_tracker in latency_hiding_scheduler, so that they can be reused across different users without constructing new ones.
FUTURE_COPYBARA_INTEGRATE_REVIEW=#25166 from openxla:add_slow_arg_init_alarm faa4fd1
PiperOrigin-RevId: 747709188
1 parent f3b9c5a commit 6b5a7a1

File tree

7 files changed

+58
-33
lines changed

7 files changed

+58
-33
lines changed

xla/service/BUILD

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1182,7 +1182,6 @@ cc_library(
11821182
"//xla/hlo/analysis:hlo_alias_analysis",
11831183
"//xla/hlo/analysis:hlo_reachability",
11841184
"//xla/hlo/ir:hlo",
1185-
"//xla/hlo/ir:ptrvec",
11861185
"//xla/hlo/pass:hlo_pass",
11871186
"//xla/tsl/platform:errors",
11881187
"//xla/tsl/platform:statusor",
@@ -5858,6 +5857,7 @@ cc_library(
58585857
"//xla/hlo/ir:ptrvec",
58595858
"//xla/hlo/pass:hlo_pass",
58605859
"//xla/tsl/platform:statusor",
5860+
"@com_google_absl//absl/container:btree",
58615861
"@com_google_absl//absl/container:flat_hash_map",
58625862
"@com_google_absl//absl/container:flat_hash_set",
58635863
"@com_google_absl//absl/log",

xla/service/latency_hiding_scheduler.cc

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2926,15 +2926,16 @@ absl::StatusOr<bool> LatencyHidingScheduler::Run(
29262926
saved_schedules[computation] = std::move(new_schedule);
29272927
}
29282928
}
2929-
LOG(INFO) << "LatencyHidingScheduler current memory usage: "
2929+
LOG(INFO) << "[" << name() << "]"
2930+
<< " LatencyHidingScheduler current memory usage: "
29302931
<< scheduler_core_->GetMemoryPeak()
29312932
<< " bytes. Current limit: " << scheduler_core_->GetMemoryLimit();
29322933
for (HloComputation* computation : computations_to_schedule) {
2933-
VLOG(1) << "Statistics before scheduling:";
2934+
VLOG(1) << "[" << name() << "] Statistics before scheduling:";
29342935
LogScheduleStatistics(computation);
29352936
module->schedule().set_sequence(
29362937
computation, absl::MakeConstSpan(saved_schedules[computation]));
2937-
VLOG(1) << "Statistics after scheduling:";
2938+
VLOG(1) << "[" << name() << "] Statistics after scheduling:";
29382939
LogScheduleStatistics(computation);
29392940
}
29402941
if (debug_options.xla_dump_latency_hiding_schedule()) {

xla/service/latency_hiding_scheduler.h

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,6 @@ limitations under the License.
4343
#include "xla/hlo/ir/hlo_instruction.h"
4444
#include "xla/hlo/ir/hlo_opcode.h"
4545
#include "xla/hlo/ir/hlo_schedule.h"
46-
#include "xla/hlo/ir/ptrvec.h"
4746
#include "xla/hlo/pass/hlo_pass_interface.h"
4847
#include "xla/map_util.h"
4948
#include "xla/service/hlo_buffer.h"
@@ -1224,8 +1223,8 @@ class LatencyHidingScheduler : public HloModulePass {
12241223
};
12251224

12261225
LatencyHidingScheduler(
1227-
std::unique_ptr<LatencyEstimator> latency_estimator,
1228-
std::unique_ptr<AsyncTracker> async_tracker,
1226+
std::shared_ptr<LatencyEstimator> latency_estimator,
1227+
std::shared_ptr<AsyncTracker> async_tracker,
12291228
std::unique_ptr<SchedulerCore> scheduler_core,
12301229
const HloCostAnalysis::ShapeSizeFunction& shape_size_bytes)
12311230
: latency_estimator_(std::move(latency_estimator)),
@@ -1252,8 +1251,8 @@ class LatencyHidingScheduler : public HloModulePass {
12521251
virtual void LogScheduleStatistics(const HloComputation* computation);
12531252

12541253
private:
1255-
std::unique_ptr<LatencyEstimator> latency_estimator_;
1256-
std::unique_ptr<AsyncTracker> async_tracker_;
1254+
std::shared_ptr<LatencyEstimator> latency_estimator_;
1255+
std::shared_ptr<AsyncTracker> async_tracker_;
12571256
std::unique_ptr<SchedulerCore> scheduler_core_;
12581257
const HloCostAnalysis::ShapeSizeFunction shape_size_bytes_;
12591258
absl::flat_hash_set<HloComputation*> computations_to_schedule_;

xla/service/legalize_scheduling_annotations.cc

Lines changed: 29 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ limitations under the License.
2121
#include <string>
2222
#include <vector>
2323

24+
#include "absl/container/btree_map.h"
2425
#include "absl/container/flat_hash_map.h"
2526
#include "absl/container/flat_hash_set.h"
2627
#include "absl/log/check.h"
@@ -48,7 +49,8 @@ namespace {
4849
// Given a group of annotated instructions (sources), find all reachable
4950
// instructions from them in the same computation.
5051
absl::flat_hash_set<HloInstruction*> PropagateAnnotationFromSources(
51-
const std::vector<HloInstruction*>& sources, HloComputation* computation) {
52+
const std::vector<HloInstruction*>& sources,
53+
const HloComputation* computation) {
5254
absl::flat_hash_set<HloInstruction*> to_annotate;
5355
auto reachability = HloReachabilityMap::Build(computation);
5456
// worklist contains instructions that can reach any source instruction.
@@ -107,28 +109,13 @@ absl::Status AttachAnnotation(
107109
absl::StrCat(annotation_id) + " to " + std::string(instr->name()) +
108110
" but it has an existing annotation: " + absl::StrCat(*id));
109111
}
112+
LOG(INFO) << "Propagating annotation " << annotation_id << " to "
113+
<< instr->name();
110114
SetSchedulingAnnotation(instr, annotation_id);
111115
}
112116
return absl::OkStatus();
113117
}
114118

115-
absl::StatusOr<bool> PropagateAnnotations(
116-
HloComputation* computation,
117-
const absl::flat_hash_map<int64_t, std::vector<HloInstruction*>>&
118-
annotation_id_to_instructions) {
119-
bool changed = false;
120-
for (auto& [annotation_id, sources] : annotation_id_to_instructions) {
121-
absl::flat_hash_set<HloInstruction*> to_annotate =
122-
PropagateAnnotationFromSources(sources, computation);
123-
changed |= (!to_annotate.empty());
124-
auto status = AttachAnnotation(annotation_id, to_annotate);
125-
if (!status.ok()) {
126-
return status;
127-
}
128-
}
129-
return changed;
130-
}
131-
132119
absl::StatusOr<int64_t> ExtractAnnotation(
133120
const ::google::protobuf::Map<std::string, std::string>& attrs,
134121
absl::string_view instr_name) {
@@ -261,6 +248,23 @@ absl::Status CheckGapBetweenAnnotatedInstructions(
261248

262249
} // namespace
263250

251+
absl::StatusOr<bool> LegalizeSchedulingAnnotations::PropagateAnnotations(
252+
const HloComputation* computation,
253+
const absl::btree_map<int64_t, std::vector<HloInstruction*>>&
254+
annotation_id_to_instructions) {
255+
bool changed = false;
256+
for (auto& [annotation_id, sources] : annotation_id_to_instructions) {
257+
absl::flat_hash_set<HloInstruction*> to_annotate =
258+
PropagateAnnotationFromSources(sources, computation);
259+
changed |= (!to_annotate.empty());
260+
auto status = AttachAnnotation(annotation_id, to_annotate);
261+
if (!status.ok()) {
262+
return status;
263+
}
264+
}
265+
return changed;
266+
}
267+
264268
bool LegalizeSchedulingAnnotations::KeepSchedulingAnnotation(
265269
HloInstruction* instr) {
266270
return IsSupportedAsyncOp(instr) || config_.keep_sync_annotation(instr);
@@ -348,10 +352,12 @@ absl::StatusOr<bool> LegalizeSchedulingAnnotations::Run(
348352
return false;
349353
}
350354

351-
auto status = CheckStartDoneAnnotationConsistency(annotation_to_instructions,
352-
annotation);
353-
if (!status.ok()) {
354-
return status;
355+
if (config_.check_start_done_annotation_consistency) {
356+
auto status = CheckStartDoneAnnotationConsistency(
357+
annotation_to_instructions, annotation);
358+
if (!status.ok()) {
359+
return status;
360+
}
355361
}
356362

357363
bool changed = false;
@@ -362,7 +368,7 @@ absl::StatusOr<bool> LegalizeSchedulingAnnotations::Run(
362368
// same annotation ID.
363369
for (HloComputation* computation :
364370
module->MakeNonfusionComputations(execution_threads)) {
365-
absl::flat_hash_map<int64_t, std::vector<HloInstruction*>>
371+
absl::btree_map<int64_t, std::vector<HloInstruction*>>
366372
per_computation_annotation_to_instructions;
367373
for (const auto& [annotation_id, comp_inst_vector] :
368374
annotation_to_instructions) {

xla/service/legalize_scheduling_annotations.h

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,11 +16,16 @@ limitations under the License.
1616
#ifndef XLA_SERVICE_LEGALIZE_SCHEDULING_ANNOTATIONS_H_
1717
#define XLA_SERVICE_LEGALIZE_SCHEDULING_ANNOTATIONS_H_
1818

19+
#include <cstdint>
1920
#include <utility>
21+
#include <vector>
2022

23+
#include "absl/container/btree_map.h"
24+
#include "absl/container/flat_hash_map.h"
2125
#include "absl/container/flat_hash_set.h"
2226
#include "absl/status/statusor.h"
2327
#include "absl/strings/string_view.h"
28+
#include "xla/hlo/ir/hlo_computation.h"
2429
#include "xla/hlo/ir/hlo_module.h"
2530
#include "xla/hlo/pass/hlo_pass_interface.h"
2631
#include "xla/util.h"
@@ -34,13 +39,20 @@ class LegalizeSchedulingAnnotations : public HloModulePass {
3439
struct Config {
3540
HloPredicate keep_sync_annotation = HloPredicateTrue;
3641
bool propagate_annotation = false;
42+
bool check_start_done_annotation_consistency = true;
3743
};
3844

3945
explicit LegalizeSchedulingAnnotations(Config config)
4046
: config_(std::move(config)) {}
4147
absl::string_view name() const override {
4248
return "legalize-scheduling-annotations";
4349
}
50+
51+
static absl::StatusOr<bool> PropagateAnnotations(
52+
const HloComputation* computation,
53+
const absl::btree_map<int64_t, std::vector<HloInstruction*>>&
54+
annotation_id_to_instructions);
55+
4456
using HloPassInterface::Run;
4557
absl::StatusOr<bool> Run(
4658
HloModule* module,

xla/tools/multihost_hlo_runner/BUILD

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -168,6 +168,7 @@ cc_library(
168168
"//xla/service:hlo_module_config",
169169
"//xla/service:hlo_module_util",
170170
"//xla/service:hlo_proto_cc",
171+
"//xla/service:slow_operation_alarm",
171172
"//xla/tests:test_utils",
172173
"//xla/tools:hlo_control_flow_flattening",
173174
"//xla/tsl/platform:env",

xla/tools/multihost_hlo_runner/functional_hlo_runner.cc

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,7 @@ limitations under the License.
6565
#include "xla/service/hlo.pb.h"
6666
#include "xla/service/hlo_module_config.h"
6767
#include "xla/service/hlo_module_util.h"
68+
#include "xla/service/slow_operation_alarm.h"
6869
#include "xla/shape_layout.h"
6970
#include "xla/shape_util.h"
7071
#include "xla/status_macros.h"
@@ -1161,6 +1162,11 @@ FunctionalHloRunner::CreateArgumentsOnDevice(
11611162
client, executable, running_options, flatten_arguments);
11621163
}
11631164

1165+
SlowOperationAlarm alarm(
1166+
absl::Seconds(5),
1167+
absl::StrFormat("Argument initialization is slow. Consider changing "
1168+
"--hlo_argument_mode."));
1169+
11641170
absl::Span<PjRtDevice* const> addressable_devices =
11651171
executable->addressable_devices();
11661172
size_t num_addressable_devices = addressable_devices.size();
@@ -1185,7 +1191,7 @@ FunctionalHloRunner::CreateArgumentsOnDevice(
11851191
ModuleArgumentMode::kUseZerosAsInput;
11861192

11871193
for (int i = 0; i < num_addressable_devices; ++i) {
1188-
VLOG(3) << "Creating fake argument for device " << i;
1194+
VLOG(3) << "Creating fake arguments for device " << i;
11891195
LiteralVec& argument_literals =
11901196
per_device_argument_literals[addressable_devices[i]->id()];
11911197
int executable_idx = hlo_modules.size() == 1

0 commit comments

Comments
 (0)