@@ -897,29 +897,22 @@ def initialize_model_parallel(
         get_world_group().device_group)
 
     data_parallel_size = 1
-    has_external_dp = False
     from vllm.config import get_current_vllm_config
     config = get_current_vllm_config()
     if config is not None:
-        if config.parallel_config.world_size != world_size:
-            # detect external data parallelism.
-            # dp in vllm means all dp instances need to run together.
-            # if the world size does not match, it means this dp is external,
-            # and the dp instances can run independently, e.g. in rlhf workflow
-            # from https://github.com/volcengine/verl .
-            # in that case, we treat the rest dimensions as if they are
-            # data parallel, and create a dummy dp group that is not used.
-            data_parallel_size = world_size // (pipeline_model_parallel_size *
-                                                tensor_model_parallel_size)
-            has_external_dp = True
-        else:
-            data_parallel_size = config.parallel_config.data_parallel_size
-
-    # the layout order is: DP x PP x TP
+        data_parallel_size = config.parallel_config.data_parallel_size
+
+    # the layout order is: ExternalDP x DP x PP x TP
+    # ExternalDP is the data parallel group that is not part of the model,
+    # every dp rank can generate independently (in verl integration).
+    # DP is the data parallel group that is part of the model,
+    # all the ranks in the same DP group should generate simultaneously,
+    # i.e. the `generate` call in the same DP group should be called together,
+    # otherwise it will cause deadlock.
     # to get group_ranks for each dimension, transpose that dimension to the
     # last dimension, then reshape to 2D, then unbind the last dimension
     all_ranks = torch.arange(world_size).reshape(
-        data_parallel_size, pipeline_model_parallel_size,
+        -1, data_parallel_size, pipeline_model_parallel_size,
         tensor_model_parallel_size)  # noqa
 
     # Build the tensor model-parallel groups.
@@ -939,7 +932,7 @@ def initialize_model_parallel(
     global _PP
     assert _PP is None, (
         "pipeline model parallel group is already initialized")
-    group_ranks = all_ranks.transpose(1, 2).reshape(
+    group_ranks = all_ranks.transpose(2, 3).reshape(
         -1, pipeline_model_parallel_size).unbind(0)
     group_ranks = [x.tolist() for x in group_ranks]
     _PP = init_model_parallel_group(group_ranks,
@@ -949,16 +942,10 @@ def initialize_model_parallel(
 
     global _DP
     assert _DP is None, ("data parallel group is already initialized")
-    group_ranks = all_ranks.transpose(0,
-                                      2).reshape(-1,
+    group_ranks = all_ranks.transpose(1,
+                                      3).reshape(-1,
                                                  data_parallel_size).unbind(0)
     group_ranks = [x.tolist() for x in group_ranks]
-    if has_external_dp:
-        # create a dummy dp group that is not used actually,
-        # since this dp is external.
-        # a dummy dp group means every rank is a group itself.
-        # this way, no communication is needed, no memory is wasted.
-        group_ranks = [[x] for x in range(world_size)]
     _DP = init_model_parallel_group(group_ranks,
                                     get_world_group().local_rank,
                                     backend,
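
For context, here is a small self-contained sketch of the transpose/reshape/unbind trick the new ExternalDP x DP x PP x TP layout relies on. The sizes, the `groups_along` helper, and the printed values are illustrative assumptions, not part of the vLLM API; only `torch` is required.

import torch

# Made-up sizes: 2 external DP replicas, DP=2, PP=2, TP=2 -> world_size=16.
dp, pp, tp = 2, 2, 2
world_size = 16

# Layout order: ExternalDP x DP x PP x TP; the external dimension is
# inferred with -1, as in the diff above.
all_ranks = torch.arange(world_size).reshape(-1, dp, pp, tp)


def groups_along(dim: int, size: int) -> list[list[int]]:
    """Transpose `dim` to the last axis, flatten to 2D, one group per row."""
    last = all_ranks.ndim - 1
    rows = all_ranks.transpose(dim, last).reshape(-1, size).unbind(0)
    return [row.tolist() for row in rows]


tp_groups = groups_along(3, tp)  # contiguous ranks, e.g. [0, 1], [2, 3], ...
pp_groups = groups_along(2, pp)  # matches transpose(2, 3) in the diff
dp_groups = groups_along(1, dp)  # matches transpose(1, 3) in the diff

# Key property: no group ever crosses the external-DP boundary, so each
# external replica (e.g. one verl engine instance) owns complete TP/PP/DP
# groups and can call generate() on its own -- which is why the old
# has_external_dp dummy-group workaround can be dropped.
for group in tp_groups + pp_groups + dp_groups:
    assert len({rank // (dp * pp * tp) for rank in group}) == 1
print(dp_groups)  # e.g. [[0, 4], [2, 6], [1, 5], [3, 7], [8, 12], ...]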