@@ -802,6 +802,7 @@ def combine(self, hidden_states) -> torch.Tensor:
 
 
 _WORLD: Optional[GroupCoordinator] = None
+_NODE_COUNT: Optional[int] = None
 
 
 def get_world_group() -> GroupCoordinator:
@@ -961,10 +962,13 @@ def init_distributed_environment(
             local_rank = envs.LOCAL_RANK
         else:
             local_rank = rank
-    global _WORLD
+    global _WORLD, _NODE_COUNT
     if _WORLD is None:
         ranks = list(range(torch.distributed.get_world_size()))
         _WORLD = init_world_group(ranks, local_rank, backend)
+        _NODE_COUNT = _node_count(_WORLD.cpu_group)
+        logger.debug("Detected %d nodes in the distributed environment",
+                     _NODE_COUNT)
     else:
         assert _WORLD.world_size == torch.distributed.get_world_size(), (
             "world group already initialized with a different world size")
@@ -1164,6 +1168,13 @@ def get_tensor_model_parallel_rank():
     return get_tp_group().rank_in_group
 
 
+def get_node_count() -> int:
+    """Return the total number of nodes in the distributed environment."""
+    assert _NODE_COUNT is not None, (
+        "distributed environment is not initialized")
+    return _NODE_COUNT
+
+
 def destroy_model_parallel():
     """Set the groups to none and destroy them."""
     global _TP
@@ -1189,10 +1200,11 @@ def destroy_model_parallel():
 
 
 def destroy_distributed_environment():
-    global _WORLD
+    global _WORLD, _NODE_COUNT
    if _WORLD:
         _WORLD.destroy()
     _WORLD = None
+    _NODE_COUNT = None
     if torch.distributed.is_initialized():
         torch.distributed.destroy_process_group()
 
@@ -1301,3 +1313,42 @@ def in_the_same_node_as(pg: Union[ProcessGroup, StatelessProcessGroup],
         aggregated_data += rank_data
 
     return [x == 1 for x in aggregated_data.tolist()]
+
+
+def _node_count(pg: Union[ProcessGroup, StatelessProcessGroup]) -> int:
+    """
+    Returns the total number of nodes in the process group.
+
+    Args:
+        pg: The process group to analyze
+
+    Returns:
+        int: The total number of nodes
+    """
+    if isinstance(pg, ProcessGroup):
+        world_size = torch.distributed.get_world_size(group=pg)
+    else:
+        world_size = pg.world_size
+
+    if world_size == 1:
+        return 1
+
+    # Build node assignment map
+    node_assignment = [0] * world_size  # rank -> node_id
+    next_node_id = 0
+
+    for current_rank in range(world_size):
+        if node_assignment[current_rank] != 0:
+            continue  # Already assigned to a node
+
+        # Assign current rank to a new node
+        next_node_id += 1
+        node_assignment[current_rank] = next_node_id
+
+        # Find all ranks on the same node as current_rank
+        same_node_flags = in_the_same_node_as(pg, current_rank)
+        for other_rank, is_same_node in enumerate(same_node_flags):
+            if is_same_node and node_assignment[other_rank] == 0:
+                node_assignment[other_rank] = next_node_id
+
+    return next_node_id
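
The grouping pass in _node_count can be exercised without torch or a real
process group by stubbing in_the_same_node_as with a rank-to-host map. A
minimal standalone sketch, assuming a hypothetical four-rank, two-node
topology (HOST_OF, fake_in_the_same_node_as, and count_nodes are illustrative
stand-ins, not vLLM APIs or part of this commit):

    # Simulated topology: rank -> hostname for 4 ranks on 2 nodes.
    HOST_OF = ["node-a", "node-a", "node-b", "node-b"]

    def fake_in_the_same_node_as(source_rank: int) -> list:
        # One flag per rank, mirroring in_the_same_node_as(pg, source_rank).
        return [host == HOST_OF[source_rank] for host in HOST_OF]

    def count_nodes(world_size: int) -> int:
        node_assignment = [0] * world_size  # 0 means "not yet assigned"
        next_node_id = 0
        for current_rank in range(world_size):
            if node_assignment[current_rank] != 0:
                continue
            # The first unassigned rank opens a new node...
            next_node_id += 1
            node_assignment[current_rank] = next_node_id
            # ...and claims every rank reported to share that node.
            flags = fake_in_the_same_node_as(current_rank)
            for other_rank, same in enumerate(flags):
                if same and node_assignment[other_rank] == 0:
                    node_assignment[other_rank] = next_node_id
        return next_node_id

    assert count_nodes(len(HOST_OF)) == 2

Because every rank that already has a node_id is skipped, the sweep issues one
in_the_same_node_as query per node rather than per rank. The result is
computed once in init_distributed_environment and cached in _NODE_COUNT, so
later callers read it via get_node_count() instead of re-running the sweep.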