Commit 7793ab2

Change unique_ptr to shared_ptr for latency_estimator and async_tracker in latency_hiding_scheduler, so that we can potentially reuse them across different users without needing to re-construct new ones.
PiperOrigin-RevId: 747709188
1 parent 4fbfe4f commit 7793ab2
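
The header change that implements this ownership switch is not part of the excerpt below, which covers the auto-sharding solver and logging changes. As a rough illustration of what the commit message describes (class shape, constructor, and member names here are assumptions, not taken from this diff):

// Sketch only: moving from exclusive to shared ownership so that one
// LatencyEstimator / AsyncTracker instance can be reused across users of the
// scheduler instead of being re-constructed for each one.
#include <memory>
#include <utility>

class LatencyEstimator {};
class AsyncTracker {};

class SchedulerSketch {
 public:
  // Before: std::unique_ptr<LatencyEstimator>, std::unique_ptr<AsyncTracker>.
  // After: shared_ptr lets several scheduler users hold the same instances.
  SchedulerSketch(std::shared_ptr<LatencyEstimator> latency_estimator,
                  std::shared_ptr<AsyncTracker> async_tracker)
      : latency_estimator_(std::move(latency_estimator)),
        async_tracker_(std::move(async_tracker)) {}

 private:
  std::shared_ptr<LatencyEstimator> latency_estimator_;
  std::shared_ptr<AsyncTracker> async_tracker_;
};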

File tree: 10 files changed (+563, -34 lines)

xla/hlo/experimental/auto_sharding/auto_sharding_solver.cc

Lines changed: 318 additions & 2 deletions
@@ -64,6 +64,7 @@ namespace spmd {
 using ::operations_research::MPConstraint;
 using ::operations_research::MPSolver;
 using ::operations_research::MPVariable;
+using EdgeAdjacency = std::vector<std::vector<EdgeIdx>>;

 // We need to nudge the maximum cost (if present) slightly, since the constraint
 // solver cannot guarantee exact numerical precision.
@@ -935,6 +936,235 @@ EdgeStrategyIdx GetEdgeStrategy(
   return node_strategies[u] * num_v_strategies + node_strategies[v];
 }

+// Stores the active times for each node.
+std::vector<std::vector<LivenessIdx>> GetNodeToActiveTimes(
+    const AutoShardingSolverRequest& request) {
+  std::vector<std::vector<LivenessIdx>> node_to_active_times(
+      request.num_nodes());
+  for (LivenessIdx t = 0; t < request.live_size(); ++t) {
+    for (NodeIdx node : request.live(t).nodes()) {
+      node_to_active_times[node].push_back(t);
+    }
+  }
+  return node_to_active_times;
+}
+
+// Computes the memory slack for each time (i.e., budget - live memory at t)
+std::vector<double> TrackMemorySlack(
+    const AutoShardingSolverRequest& request,
+    const std::vector<NodeStrategyIdx>& node_strategies) {
+  std::vector<double> memory_slack(request.live_size(), 0.0);
+  for (LivenessIdx t = 0; t < request.live_size(); ++t) {
+    double live_memory = 0.0;
+    for (NodeIdx node : request.live(t).nodes()) {
+      live_memory += request.memory_costs(node).costs(node_strategies[node]);
+    }
+    memory_slack[t] = request.memory_budget() - live_memory;
+  }
+  return memory_slack;
+}
+
+std::pair<EdgeAdjacency, EdgeAdjacency> GetAdjacencyMatrix(
+    const AutoShardingSolverRequest& request) {
+  // outward_edges: i-th vector is the edges of the form (i-th node)->v.
+  // inward_edges: i-th vector is the edges of the form v->(i-th node).
+  EdgeAdjacency outward_edges(request.num_nodes());
+  EdgeAdjacency inward_edges(request.num_nodes());
+  for (EdgeIdx edge_idx = 0; edge_idx < request.edges_size(); ++edge_idx) {
+    const auto& edge = request.edges(edge_idx);
+    outward_edges[edge.first()].push_back(edge_idx);
+    inward_edges[edge.second()].push_back(edge_idx);
+  }
+  return {outward_edges, inward_edges};
+}
+
+// Store the edges within the path.
+std::vector<EdgeIdx> GetEdgesWithinPath(
+    const AutoShardingSolverRequest& request, const std::vector<NodeIdx>& path,
+    const EdgeAdjacency& outward_edges) {
+  std::vector<EdgeIdx> edges_within_path;
+  for (const NodeIdx& node : path) {
+    for (const EdgeIdx& edge : outward_edges[node]) {
+      auto it =
+          std::find(path.begin(), path.end(), request.edges(edge).second());
+      if (it != path.end()) {
+        edges_within_path.push_back(edge);
+      }
+    }
+  }
+  return edges_within_path;
+}
+
+// Sample a random path of length `path_length'.
+std::vector<NodeIdx> SamplePath(const AutoShardingSolverRequest& request,
+                                const EdgeAdjacency& outward_edges,
+                                const int path_length, std::mt19937_64& rng) {
+  std::vector<NodeIdx> path;
+  path.reserve(path_length + 1);
+  if (path_length == 0) {  // Sample a random node.
+    std::uniform_int_distribution<> dist(0, request.num_nodes() - 1);
+    path.push_back(dist(rng));
+  } else if (path_length == 1) {  // Sample a random edge.
+    std::uniform_int_distribution<> dist(0, request.edges_size() - 1);
+    EdgeIdx random_edge_idx = dist(rng);
+    path.push_back(request.edges(random_edge_idx).first());
+    path.push_back(request.edges(random_edge_idx).second());
+  } else {  // Path-sampling by concatenating nodes.
+    int scanned_length = 0;
+    std::uniform_int_distribution<> dist(0, request.edges_size() - 1);
+    NodeIdx u = request.edges(dist(rng)).first();
+    path.push_back(u);
+    while (scanned_length < path_length) {
+      // Sample edges from the outward edges of u.
+      if (outward_edges[u].empty()) {
+        break;
+      }
+      scanned_length++;
+      std::uniform_int_distribution<> dist(0, outward_edges[u].size() - 1);
+      EdgeIdx edge_idx = outward_edges[u][dist(rng)];
+      u = request.edges(edge_idx).second();
+      path.push_back(u);
+    }
+  }
+  return path;
+}
+
+// Computes the cost induced by a node and its adjacent edges.
+double AggregateCostAroundNode(
+    const AutoShardingSolverRequest& request,
+    const std::pair<EdgeAdjacency, EdgeAdjacency>& adjacency,
+    const std::vector<NodeStrategyIdx>& node_strategies, const NodeIdx& node) {
+  const EdgeAdjacency& outward_edges = adjacency.first;
+  const EdgeAdjacency& inward_edges = adjacency.second;
+  double cost = 0.0;
+  // Node cost
+  cost += request.computation_costs(node).costs(node_strategies[node]) +
+          request.communication_costs(node).costs(node_strategies[node]);
+
+  // Edge cost
+  for (const EdgeIdx& outward_edge : outward_edges[node]) {
+    cost += request.resharding_costs(outward_edge)
+                .costs(GetEdgeStrategy(request, node_strategies, outward_edge));
+  }
+  for (const EdgeIdx& inward_edge : inward_edges[node]) {
+    cost += request.resharding_costs(inward_edge)
+                .costs(GetEdgeStrategy(request, node_strategies, inward_edge));
+  }
+  return cost;
+}
+
+// Computes the cost induced by a path (cost of nodes and adjacent edges).
+double CostOverPath(const AutoShardingSolverRequest& request,
+                    const std::pair<EdgeAdjacency, EdgeAdjacency>& adjacency,
+                    const std::vector<NodeIdx>& path,
+                    const std::vector<EdgeIdx>& edges_within_path,
+                    std::vector<NodeStrategyIdx>& node_strategies) {
+  double cost = 0.0;
+  for (const NodeIdx& node : path) {
+    cost += AggregateCostAroundNode(request, adjacency, node_strategies, node);
+  }
+  // Subtracting the overcounted edge costs within the path.
+  for (const EdgeIdx& edge : edges_within_path) {
+    EdgeStrategyIdx edge_strategy =
+        GetEdgeStrategy(request, node_strategies, edge);
+    cost -= request.resharding_costs(edge).costs(edge_strategy);
+  }
+  return cost;
+}
+
+// Recursively optimizes over the path.
+std::pair<double, std::vector<NodeStrategyIdx>> _OptimizeOverPath(
+    const AutoShardingSolverRequest& request, const std::vector<NodeIdx>& path,
+    const std::vector<EdgeIdx>& edges_within_path,
+    std::vector<NodeStrategyIdx>& node_strategies,
+    const std::pair<EdgeAdjacency, EdgeAdjacency>& adjacency,
+    int num_remaining_nodes) {
+  double best_cost = std::numeric_limits<double>::infinity();
+  std::vector<NodeStrategyIdx> best_strategy(path.size(), 0);
+  for (int i = 0; i < path.size(); ++i) {
+    best_strategy[i] = node_strategies[path[i]];
+  }
+
+  if (num_remaining_nodes == 1) {  // Base case of the recursion.
+    NodeIdx last_node = path[path.size() - 1];
+    for (NodeStrategyIdx node_strategy = 0;
+         node_strategy < request.computation_costs(last_node).costs_size();
+         ++node_strategy) {
+      node_strategies[last_node] = node_strategy;
+      double path_cost = CostOverPath(request, adjacency, path,
+                                      edges_within_path, node_strategies);
+      if (path_cost < best_cost) {
+        best_cost = path_cost;
+        best_strategy[best_strategy.size() - 1] = node_strategy;
+      }
+    }
+  } else {
+    NodeIdx current_node = path[path.size() - num_remaining_nodes];
+    for (NodeStrategyIdx node_strategy = 0;
+         node_strategy < request.computation_costs(current_node).costs_size();
+         ++node_strategy) {
+      node_strategies[current_node] = node_strategy;
+      auto [path_cost, path_strategy] =
+          _OptimizeOverPath(request, path, edges_within_path, node_strategies,
+                            adjacency, num_remaining_nodes - 1);
+      if (path_cost < best_cost) {
+        best_cost = path_cost;
+        best_strategy = path_strategy;
+      }
+    }
+  }
+  return std::make_pair(best_cost, best_strategy);
+}
+
+// A wrapper function for `_OptimizeOverPath`, which is a recursive
+// function to find the best sharding strategies for the path.
+std::vector<NodeStrategyIdx> OptimizeOverPath(
+    const AutoShardingSolverRequest& request, const std::vector<NodeIdx>& path,
+    std::vector<NodeStrategyIdx>& node_strategies,
+    const std::pair<EdgeAdjacency, EdgeAdjacency>& adjacency) {
+  std::vector<NodeStrategyIdx> old_strategies(path.size(), 0);
+  for (int i = 0; i < path.size(); ++i) {
+    old_strategies[i] = node_strategies[path[i]];
+  }
+  std::vector<EdgeIdx> edges_within_path =
+      GetEdgesWithinPath(request, path, /*outward_edges=*/adjacency.first);
+
+  auto [_, best_path_strategies] =
+      _OptimizeOverPath(request, path, edges_within_path, node_strategies,
+                        adjacency, path.size());
+
+  // node_strategies could change within _OptimizeOverPath, so we restore the
+  // original sharding strategies for the nodes on the path.
+  for (int i = 0; i < path.size(); ++i) {
+    node_strategies[path[i]] = old_strategies[i];
+  }
+  return best_path_strategies;
+}
+
+// Check if a path's new configuration satisfies the memory constraints.
+absl::flat_hash_map<LivenessIdx, double> GetNewMemorySlack(
+    const AutoShardingSolverRequest& request, const std::vector<NodeIdx>& path,
+    const std::vector<NodeStrategyIdx>& path_strategies,
+    const std::vector<NodeStrategyIdx>& node_strategies,
+    const std::vector<std::vector<LivenessIdx>>& node_to_active_times,
+    const std::vector<double>& memory_slack) {
+  absl::flat_hash_map<LivenessIdx, double> new_memory_slack;
+  for (int i = 0; i < path.size(); ++i) {
+    NodeIdx node = path[i];
+    if (!node_to_active_times[node].empty()) {
+      for (LivenessIdx t : node_to_active_times[node]) {
+        if (!new_memory_slack.contains(t)) {
+          new_memory_slack[t] = memory_slack[t];
+        }
+        new_memory_slack[t] -=
+            (request.memory_costs(node).costs(path_strategies[i]) -
+             request.memory_costs(node).costs(node_strategies[node]));
+      }
+    }
+  }
+  return new_memory_slack;
+}
+
 // Checks if the node-sharding strategy has a finite cost and satisfies the
 // peak-memory constraint.
 std::optional<AutoShardingViolationCode> ShardingStrategyHasViolation(
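
To make the slack bookkeeping above concrete: the slack at time t is the memory budget minus the memory used by the nodes live at t, and re-sharding one node shifts that slack by the node's memory delta at every time step where it is live. A simplified, standalone sketch (plain vectors stand in for the AutoShardingSolverRequest proto; names are illustrative):

#include <vector>

// slack[t] = budget - sum of memory costs of nodes live at time t.
std::vector<double> ComputeSlack(
    double budget, const std::vector<std::vector<int>>& live_nodes_at_time,
    const std::vector<double>& node_memory) {
  std::vector<double> slack(live_nodes_at_time.size(), 0.0);
  for (size_t t = 0; t < live_nodes_at_time.size(); ++t) {
    double live_memory = 0.0;
    for (int node : live_nodes_at_time[t]) live_memory += node_memory[node];
    slack[t] = budget - live_memory;
  }
  return slack;
}

// A strategy change on a node is feasible only if the shifted slack stays
// non-negative at every time step where that node is live.
bool StillFeasible(const std::vector<double>& slack,
                   const std::vector<int>& active_times, double old_mem,
                   double new_mem) {
  for (int t : active_times) {
    if (slack[t] - (new_mem - old_mem) < 0.0) return false;
  }
  return true;
}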
@@ -1021,8 +1251,8 @@ AutoShardingSolverOutput SolveRandom(const AutoShardingSolverRequest& request,
 }

 // Greedily selects the node sharding strategies. Valid modes:
-// - "node_cost"
-// - "node_memory"
+// - "node-cost"
+// - "node-memory"
 AutoShardingSolverOutput SolveGreedy(const AutoShardingSolverRequest& request,
                                      const std::string& mode) {
   const int num_nodes = request.num_nodes();
@@ -1058,6 +1288,88 @@ AutoShardingSolverOutput SolveGreedy(const AutoShardingSolverRequest& request,
   return output;
 }

+// A local search algorithm that iteratively picks a random path of length
+// `path_length` and computes the best sharding configuration for the path.
+// - `path_length = 0` corresponds to a random node.
+// - `path_length = 1` corresponds to a random edge.
+// It has two `memory_mode` options for how it handles peak-memory constraints:
+// - "inactive": ignores peak-memory constraints
+// - "active": treats the peak-memory usage as a hard constraint
+AutoShardingSolverOutput SolveRandomPathGreedy(
+    const AutoShardingSolverRequest& request, const int path_length,
+    const int num_trials, const std::string& memory_mode) {
+  std::mt19937_64 rng(0);
+  if (memory_mode != "inactive" && memory_mode != "active") {
+    CHECK(false) << absl::Substitute("Memory mode $0 is not implemented.",
+                                     memory_mode);
+  }
+
+  // Initialize each node's sharding strategy with the least-memory usage.
+  std::vector<NodeStrategyIdx> node_strategies =
+      SolveGreedy(request, "node-memory").s_val;
+  const std::pair<EdgeAdjacency, EdgeAdjacency> adjacency =
+      GetAdjacencyMatrix(request);
+  std::vector<std::vector<LivenessIdx>> node_to_active_times;
+  std::vector<double> memory_slack;
+  if (memory_mode == "active") {
+    node_to_active_times = GetNodeToActiveTimes(request);
+    memory_slack = TrackMemorySlack(request, node_strategies);
+  }
+
+  for (int trial = 0; trial < num_trials; ++trial) {
+    std::vector<NodeIdx> path =
+        SamplePath(request, adjacency.first, path_length, rng);
+    if (path.size() != path_length + 1) {
+      continue;
+    }
+    const std::vector<NodeStrategyIdx> new_path_strategies =
+        OptimizeOverPath(request, path, node_strategies, adjacency);
+
+    if (memory_mode == "inactive") {
+      for (int i = 0; i < path.size(); ++i) {
+        node_strategies[path[i]] = new_path_strategies[i];
+      }
+    } else if (memory_mode == "active") {
+      // Check: the new strategy over the path is different from the old one.
+      bool better_sharding = false;
+      for (int i = 0; i < path.size(); ++i) {
+        if (node_strategies[path[i]] != new_path_strategies[i]) {
+          better_sharding = true;
+          break;
+        }
+      }
+
+      if (better_sharding) {
+        // Check: the new strategy satisfies the memory constraints.
+        const auto new_memory_slack_at_times = GetNewMemorySlack(
+            request, path, new_path_strategies, node_strategies,
+            node_to_active_times, memory_slack);
+        bool memory_feasible = true;
+        for (const auto& [time_step, new_slack] : new_memory_slack_at_times) {
+          if (new_slack < 0.0) {
+            memory_feasible = false;
+            break;
+          }
+        }
+        // If feasible, update the sharding strategies and memory slack.
+        if (memory_feasible) {
+          for (const auto& [time_step, new_slack] : new_memory_slack_at_times) {
+            memory_slack[time_step] = new_slack;
+          }
+          for (int i = 0; i < path.size(); ++i) {
+            node_strategies[path[i]] = new_path_strategies[i];
+          }
+        }
+      }
+    }
+  }
+
+  AutoShardingSolverOutput output;
+  output.s_val = node_strategies;
+  output.cost = ComputeShardingStrategyCost(request, node_strategies);
+  return output;
+}
+
 }  // namespace

 absl::StatusOr<AutoShardingSolverOutput> RunHeuristicSolver(
@@ -1072,6 +1384,10 @@ absl::StatusOr<AutoShardingSolverOutput> RunHeuristicSolver(
     output = SolveGreedy(request, "node-cost");
   } else if (algorithm == "greedy-node-memory") {
     output = SolveGreedy(request, "node-memory");
+  } else if (algorithm == "random-path-greedy") {
+    output =
+        SolveRandomPathGreedy(request, /*path_length=*/2,
+                              /*num_trials=*/100000, /*memory_mode=*/"active");
   } else if (algorithm == "brkga") {
     output = SolveBrkga(request);
   } else {
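
For context on how the new heuristic is selected, a hedged usage sketch (the dispatch above suggests RunHeuristicSolver takes the request and an algorithm name; the exact parameter list is not shown in this diff):

// Illustrative only; check the solver header for the real signature.
absl::StatusOr<AutoShardingSolverOutput> output =
    RunHeuristicSolver(request, /*algorithm=*/"random-path-greedy");
if (output.ok()) {
  // output->s_val holds one strategy index per node; output->cost is the
  // objective value of that assignment.
}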

xla/service/BUILD

Lines changed: 1 addition & 1 deletion
@@ -1182,7 +1182,6 @@ cc_library(
         "//xla/hlo/analysis:hlo_alias_analysis",
         "//xla/hlo/analysis:hlo_reachability",
         "//xla/hlo/ir:hlo",
-        "//xla/hlo/ir:ptrvec",
         "//xla/hlo/pass:hlo_pass",
         "//xla/tsl/platform:errors",
         "//xla/tsl/platform:statusor",
@@ -5858,6 +5857,7 @@ cc_library(
         "//xla/hlo/ir:ptrvec",
         "//xla/hlo/pass:hlo_pass",
         "//xla/tsl/platform:statusor",
+        "@com_google_absl//absl/container:btree",
         "@com_google_absl//absl/container:flat_hash_map",
         "@com_google_absl//absl/container:flat_hash_set",
         "@com_google_absl//absl/log",

xla/service/latency_hiding_scheduler.cc

Lines changed: 4 additions & 3 deletions
@@ -2926,15 +2926,16 @@ absl::StatusOr<bool> LatencyHidingScheduler::Run(
       saved_schedules[computation] = std::move(new_schedule);
     }
   }
-  LOG(INFO) << "LatencyHidingScheduler current memory usage: "
+  LOG(INFO) << "[" << name() << "]"
+            << " LatencyHidingScheduler current memory usage: "
             << scheduler_core_->GetMemoryPeak()
             << " bytes. Current limit: " << scheduler_core_->GetMemoryLimit();
   for (HloComputation* computation : computations_to_schedule) {
-    VLOG(1) << "Statistics before scheduling:";
+    VLOG(1) << "[" << name() << "] Statistics before scheduling:";
     LogScheduleStatistics(computation);
     module->schedule().set_sequence(
         computation, absl::MakeConstSpan(saved_schedules[computation]));
-    VLOG(1) << "Statistics after scheduling:";
+    VLOG(1) << "[" << name() << "] Statistics after scheduling:";
     LogScheduleStatistics(computation);
   }
   if (debug_options.xla_dump_latency_hiding_schedule()) {
