@@ -64,6 +64,7 @@ namespace spmd {
6464using ::operations_research::MPConstraint;
6565using ::operations_research::MPSolver;
6666using ::operations_research::MPVariable;
67+ using EdgeAdjacency = std::vector<std::vector<EdgeIdx>>;
6768
6869// We need to nudge the maximum cost (if present) slightly, since the constraint
6970// solver cannot guarantee exact numerical precision.
@@ -935,6 +936,235 @@ EdgeStrategyIdx GetEdgeStrategy(
935936 return node_strategies[u] * num_v_strategies + node_strategies[v];
936937}
937938
939+ // Stores the active times for each node.
940+ std::vector<std::vector<LivenessIdx>> GetNodeToActiveTimes (
941+ const AutoShardingSolverRequest& request) {
942+ std::vector<std::vector<LivenessIdx>> node_to_active_times (
943+ request.num_nodes ());
944+ for (LivenessIdx t = 0 ; t < request.live_size (); ++t) {
945+ for (NodeIdx node : request.live (t).nodes ()) {
946+ node_to_active_times[node].push_back (t);
947+ }
948+ }
949+ return node_to_active_times;
950+ }
951+
952+ // Computes the memory slack for each time (i.e., budget - live memory at t)
953+ std::vector<double > TrackMemorySlack (
954+ const AutoShardingSolverRequest& request,
955+ const std::vector<NodeStrategyIdx>& node_strategies) {
956+ std::vector<double > memory_slack (request.live_size (), 0.0 );
957+ for (LivenessIdx t = 0 ; t < request.live_size (); ++t) {
958+ double live_memory = 0.0 ;
959+ for (NodeIdx node : request.live (t).nodes ()) {
960+ live_memory += request.memory_costs (node).costs (node_strategies[node]);
961+ }
962+ memory_slack[t] = request.memory_budget () - live_memory;
963+ }
964+ return memory_slack;
965+ }
966+
967+ std::pair<EdgeAdjacency, EdgeAdjacency> GetAdjacencyMatrix (
968+ const AutoShardingSolverRequest& request) {
969+ // outward_edges: i-th vector is the edges of the form (i-th node)->v.
970+ // inward_edges: i-th vector is the edges of the form v->(i-th node).
971+ EdgeAdjacency outward_edges (request.num_nodes ());
972+ EdgeAdjacency inward_edges (request.num_nodes ());
973+ for (EdgeIdx edge_idx = 0 ; edge_idx < request.edges_size (); ++edge_idx) {
974+ const auto & edge = request.edges (edge_idx);
975+ outward_edges[edge.first ()].push_back (edge_idx);
976+ inward_edges[edge.second ()].push_back (edge_idx);
977+ }
978+ return {outward_edges, inward_edges};
979+ }
980+
981+ // Store the edges within the path.
982+ std::vector<EdgeIdx> GetEdgesWithinPath (
983+ const AutoShardingSolverRequest& request, const std::vector<NodeIdx>& path,
984+ const EdgeAdjacency& outward_edges) {
985+ std::vector<EdgeIdx> edges_within_path;
986+ for (const NodeIdx& node : path) {
987+ for (const EdgeIdx& edge : outward_edges[node]) {
988+ auto it =
989+ std::find (path.begin (), path.end (), request.edges (edge).second ());
990+ if (it != path.end ()) {
991+ edges_within_path.push_back (edge);
992+ }
993+ }
994+ }
995+ return edges_within_path;
996+ }
997+
998+ // Sample a random path of length `path_length'.
999+ std::vector<NodeIdx> SamplePath (const AutoShardingSolverRequest& request,
1000+ const EdgeAdjacency& outward_edges,
1001+ const int path_length, std::mt19937_64& rng) {
1002+ std::vector<NodeIdx> path;
1003+ path.reserve (path_length + 1 );
1004+ if (path_length == 0 ) { // Sample a random node.
1005+ std::uniform_int_distribution<> dist (0 , request.num_nodes () - 1 );
1006+ path.push_back (dist (rng));
1007+ } else if (path_length == 1 ) { // Sample a random edge.
1008+ std::uniform_int_distribution<> dist (0 , request.edges_size () - 1 );
1009+ EdgeIdx random_edge_idx = dist (rng);
1010+ path.push_back (request.edges (random_edge_idx).first ());
1011+ path.push_back (request.edges (random_edge_idx).second ());
1012+ } else { // Path-sampling by concatenating nodes.
1013+ int scanned_length = 0 ;
1014+ std::uniform_int_distribution<> dist (0 , request.edges_size () - 1 );
1015+ NodeIdx u = request.edges (dist (rng)).first ();
1016+ path.push_back (u);
1017+ while (scanned_length < path_length) {
1018+ // Sample edges from the outward edges of u.
1019+ if (outward_edges[u].empty ()) {
1020+ break ;
1021+ }
1022+ scanned_length++;
1023+ std::uniform_int_distribution<> dist (0 , outward_edges[u].size () - 1 );
1024+ EdgeIdx edge_idx = outward_edges[u][dist (rng)];
1025+ u = request.edges (edge_idx).second ();
1026+ path.push_back (u);
1027+ }
1028+ }
1029+ return path;
1030+ }
1031+
1032+ // Computes the cost induced by a node and its adjacent edges.
1033+ double AggregateCostAroundNode (
1034+ const AutoShardingSolverRequest& request,
1035+ const std::pair<EdgeAdjacency, EdgeAdjacency>& adjacency,
1036+ const std::vector<NodeStrategyIdx>& node_strategies, const NodeIdx& node) {
1037+ const EdgeAdjacency& outward_edges = adjacency.first ;
1038+ const EdgeAdjacency& inward_edges = adjacency.second ;
1039+ double cost = 0.0 ;
1040+ // Node cost
1041+ cost += request.computation_costs (node).costs (node_strategies[node]) +
1042+ request.communication_costs (node).costs (node_strategies[node]);
1043+
1044+ // Edge cost
1045+ for (const EdgeIdx& outward_edge : outward_edges[node]) {
1046+ cost += request.resharding_costs (outward_edge)
1047+ .costs (GetEdgeStrategy (request, node_strategies, outward_edge));
1048+ }
1049+ for (const EdgeIdx& inward_edge : inward_edges[node]) {
1050+ cost += request.resharding_costs (inward_edge)
1051+ .costs (GetEdgeStrategy (request, node_strategies, inward_edge));
1052+ }
1053+ return cost;
1054+ }
1055+
1056+ // Computes the cost induced by a path (cost of nodes and adjacent edges).
1057+ double CostOverPath (const AutoShardingSolverRequest& request,
1058+ const std::pair<EdgeAdjacency, EdgeAdjacency>& adjacency,
1059+ const std::vector<NodeIdx>& path,
1060+ const std::vector<EdgeIdx>& edges_within_path,
1061+ std::vector<NodeStrategyIdx>& node_strategies) {
1062+ double cost = 0.0 ;
1063+ for (const NodeIdx& node : path) {
1064+ cost += AggregateCostAroundNode (request, adjacency, node_strategies, node);
1065+ }
1066+ // Subtracting the overcounted edge costs within the path.
1067+ for (const EdgeIdx& edge : edges_within_path) {
1068+ EdgeStrategyIdx edge_strategy =
1069+ GetEdgeStrategy (request, node_strategies, edge);
1070+ cost -= request.resharding_costs (edge).costs (edge_strategy);
1071+ }
1072+ return cost;
1073+ }
1074+
1075+ // Recursively optimizes over the path.
1076+ std::pair<double , std::vector<NodeStrategyIdx>> _OptimizeOverPath (
1077+ const AutoShardingSolverRequest& request, const std::vector<NodeIdx>& path,
1078+ const std::vector<EdgeIdx>& edges_within_path,
1079+ std::vector<NodeStrategyIdx>& node_strategies,
1080+ const std::pair<EdgeAdjacency, EdgeAdjacency>& adjacency,
1081+ int num_remaining_nodes) {
1082+ double best_cost = std::numeric_limits<double >::infinity ();
1083+ std::vector<NodeStrategyIdx> best_strategy (path.size (), 0 );
1084+ for (int i = 0 ; i < path.size (); ++i) {
1085+ best_strategy[i] = node_strategies[path[i]];
1086+ }
1087+
1088+ if (num_remaining_nodes == 1 ) { // Base case of the recursion.
1089+ NodeIdx last_node = path[path.size () - 1 ];
1090+ for (NodeStrategyIdx node_strategy = 0 ;
1091+ node_strategy < request.computation_costs (last_node).costs_size ();
1092+ ++node_strategy) {
1093+ node_strategies[last_node] = node_strategy;
1094+ double path_cost = CostOverPath (request, adjacency, path,
1095+ edges_within_path, node_strategies);
1096+ if (path_cost < best_cost) {
1097+ best_cost = path_cost;
1098+ best_strategy[best_strategy.size () - 1 ] = node_strategy;
1099+ }
1100+ }
1101+ } else {
1102+ NodeIdx current_node = path[path.size () - num_remaining_nodes];
1103+ for (NodeStrategyIdx node_strategy = 0 ;
1104+ node_strategy < request.computation_costs (current_node).costs_size ();
1105+ ++node_strategy) {
1106+ node_strategies[current_node] = node_strategy;
1107+ auto [path_cost, path_strategy] =
1108+ _OptimizeOverPath (request, path, edges_within_path, node_strategies,
1109+ adjacency, num_remaining_nodes - 1 );
1110+ if (path_cost < best_cost) {
1111+ best_cost = path_cost;
1112+ best_strategy = path_strategy;
1113+ }
1114+ }
1115+ }
1116+ return std::make_pair (best_cost, best_strategy);
1117+ }
1118+
1119+ // A wrapper function for `_OptimizeOverPath`, which is a recursive
1120+ // function to find the best sharding strategies for the path.
1121+ std::vector<NodeStrategyIdx> OptimizeOverPath (
1122+ const AutoShardingSolverRequest& request, const std::vector<NodeIdx>& path,
1123+ std::vector<NodeStrategyIdx>& node_strategies,
1124+ const std::pair<EdgeAdjacency, EdgeAdjacency>& adjacency) {
1125+ std::vector<NodeStrategyIdx> old_strategies (path.size (), 0 );
1126+ for (int i = 0 ; i < path.size (); ++i) {
1127+ old_strategies[i] = node_strategies[path[i]];
1128+ }
1129+ std::vector<EdgeIdx> edges_within_path =
1130+ GetEdgesWithinPath (request, path, /* outward_edges=*/ adjacency.first );
1131+
1132+ auto [_, best_path_strategies] =
1133+ _OptimizeOverPath (request, path, edges_within_path, node_strategies,
1134+ adjacency, path.size ());
1135+
1136+ // node_strategies could change within _OptimizeOverPath, so we restore the
1137+ // original sharding strategies for the nodes on the path.
1138+ for (int i = 0 ; i < path.size (); ++i) {
1139+ node_strategies[path[i]] = old_strategies[i];
1140+ }
1141+ return best_path_strategies;
1142+ }
1143+
1144+ // Check if a path's new configuration satisfies the memory constraints.
1145+ absl::flat_hash_map<LivenessIdx, double > GetNewMemorySlack (
1146+ const AutoShardingSolverRequest& request, const std::vector<NodeIdx>& path,
1147+ const std::vector<NodeStrategyIdx>& path_strategies,
1148+ const std::vector<NodeStrategyIdx>& node_strategies,
1149+ const std::vector<std::vector<LivenessIdx>>& node_to_active_times,
1150+ const std::vector<double >& memory_slack) {
1151+ absl::flat_hash_map<LivenessIdx, double > new_memory_slack;
1152+ for (int i = 0 ; i < path.size (); ++i) {
1153+ NodeIdx node = path[i];
1154+ if (!node_to_active_times[node].empty ()) {
1155+ for (LivenessIdx t : node_to_active_times[node]) {
1156+ if (!new_memory_slack.contains (t)) {
1157+ new_memory_slack[t] = memory_slack[t];
1158+ }
1159+ new_memory_slack[t] -=
1160+ (request.memory_costs (node).costs (path_strategies[i]) -
1161+ request.memory_costs (node).costs (node_strategies[node]));
1162+ }
1163+ }
1164+ }
1165+ return new_memory_slack;
1166+ }
1167+
9381168// Checks if the node-sharding strategy has a finite cost and satisfies the
9391169// peak-memory constraint.
9401170std::optional<AutoShardingViolationCode> ShardingStrategyHasViolation (
@@ -1021,8 +1251,8 @@ AutoShardingSolverOutput SolveRandom(const AutoShardingSolverRequest& request,
10211251}
10221252
10231253// Greedily selects the node sharding strategies. Valid modes:
1024- // - "node_cost "
1025- // - "node_memory "
1254+ // - "node-cost "
1255+ // - "node-memory "
10261256AutoShardingSolverOutput SolveGreedy (const AutoShardingSolverRequest& request,
10271257 const std::string& mode) {
10281258 const int num_nodes = request.num_nodes ();
@@ -1058,6 +1288,88 @@ AutoShardingSolverOutput SolveGreedy(const AutoShardingSolverRequest& request,
10581288 return output;
10591289}
10601290
1291+ // A local search algorithm that iteratively picks a random path of length
1292+ // `path_length` and computes the best sharding configuration for the path.
1293+ // - `path_length = 0` corresponds to a random node.
1294+ // - `path_length = 1` corresponds to a random edge.
1295+ // It has two `memory_mode` options for how it handles peak-memory constraints:
1296+ // - "inactive": ignores peak-memory constraints
1297+ // - "active": treats the peak-memory usage as a hard constraint
1298+ AutoShardingSolverOutput SolveRandomPathGreedy (
1299+ const AutoShardingSolverRequest& request, const int path_length,
1300+ const int num_trials, const std::string& memory_mode) {
1301+ std::mt19937_64 rng (0 );
1302+ if (memory_mode != " inactive" && memory_mode != " active" ) {
1303+ CHECK (false ) << absl::Substitute (" Memory mode $0 is not implemented." ,
1304+ memory_mode);
1305+ }
1306+
1307+ // Initialize each node's sharding strategy with the least-memory usage.
1308+ std::vector<NodeStrategyIdx> node_strategies =
1309+ SolveGreedy (request, " node-memory" ).s_val ;
1310+ const std::pair<EdgeAdjacency, EdgeAdjacency> adjacency =
1311+ GetAdjacencyMatrix (request);
1312+ std::vector<std::vector<LivenessIdx>> node_to_active_times;
1313+ std::vector<double > memory_slack;
1314+ if (memory_mode == " active" ) {
1315+ node_to_active_times = GetNodeToActiveTimes (request);
1316+ memory_slack = TrackMemorySlack (request, node_strategies);
1317+ }
1318+
1319+ for (int trial = 0 ; trial < num_trials; ++trial) {
1320+ std::vector<NodeIdx> path =
1321+ SamplePath (request, adjacency.first , path_length, rng);
1322+ if (path.size () != path_length + 1 ) {
1323+ continue ;
1324+ }
1325+ const std::vector<NodeStrategyIdx> new_path_strategies =
1326+ OptimizeOverPath (request, path, node_strategies, adjacency);
1327+
1328+ if (memory_mode == " inactive" ) {
1329+ for (int i = 0 ; i < path.size (); ++i) {
1330+ node_strategies[path[i]] = new_path_strategies[i];
1331+ }
1332+ } else if (memory_mode == " active" ) {
1333+ // Check: the new strategy over the path is different from the old one.
1334+ bool better_sharding = false ;
1335+ for (int i = 0 ; i < path.size (); ++i) {
1336+ if (node_strategies[path[i]] != new_path_strategies[i]) {
1337+ better_sharding = true ;
1338+ break ;
1339+ }
1340+ }
1341+
1342+ if (better_sharding) {
1343+ // Check: the new strategy satisfies the memory constraints.
1344+ const auto new_memory_slack_at_times = GetNewMemorySlack (
1345+ request, path, new_path_strategies, node_strategies,
1346+ node_to_active_times, memory_slack);
1347+ bool memory_feasible = true ;
1348+ for (const auto & [time_step, new_slack] : new_memory_slack_at_times) {
1349+ if (new_slack < 0.0 ) {
1350+ memory_feasible = false ;
1351+ break ;
1352+ }
1353+ }
1354+ // If feasible, update the sharding strategies and memory slack.
1355+ if (memory_feasible) {
1356+ for (const auto & [time_step, new_slack] : new_memory_slack_at_times) {
1357+ memory_slack[time_step] = new_slack;
1358+ }
1359+ for (int i = 0 ; i < path.size (); ++i) {
1360+ node_strategies[path[i]] = new_path_strategies[i];
1361+ }
1362+ }
1363+ }
1364+ }
1365+ }
1366+
1367+ AutoShardingSolverOutput output;
1368+ output.s_val = node_strategies;
1369+ output.cost = ComputeShardingStrategyCost (request, node_strategies);
1370+ return output;
1371+ }
1372+
10611373} // namespace
10621374
10631375absl::StatusOr<AutoShardingSolverOutput> RunHeuristicSolver (
@@ -1072,6 +1384,10 @@ absl::StatusOr<AutoShardingSolverOutput> RunHeuristicSolver(
10721384 output = SolveGreedy (request, " node-cost" );
10731385 } else if (algorithm == " greedy-node-memory" ) {
10741386 output = SolveGreedy (request, " node-memory" );
1387+ } else if (algorithm == " random-path-greedy" ) {
1388+ output =
1389+ SolveRandomPathGreedy (request, /* path_length=*/ 2 ,
1390+ /* num_trials=*/ 100000 , /* memory_mode=*/ " active" );
10751391 } else if (algorithm == " brkga" ) {
10761392 output = SolveBrkga (request);
10771393 } else {
0 commit comments