@@ -897,10 +897,23 @@ def initialize_model_parallel(
         get_world_group().device_group)
 
     data_parallel_size = 1
+    has_external_dp = False
     from vllm.config import get_current_vllm_config
     config = get_current_vllm_config()
     if config is not None:
-        data_parallel_size = config.parallel_config.data_parallel_size
+        if config.parallel_config.world_size != world_size:
+            # detect external data parallelism.
+            # dp in vllm means all dp instances need to run together.
+            # if the world size does not match, it means this dp is external,
+            # and the dp instances can run independently, e.g. in rlhf workflow
+            # from https://github.com/volcengine/verl .
+            # in that case, we treat the remaining dimensions as if they are
+            # data parallel, and create a dummy dp group that is not used.
+            data_parallel_size = world_size // (pipeline_model_parallel_size *
+                                                tensor_model_parallel_size)
+            has_external_dp = True
+        else:
+            data_parallel_size = config.parallel_config.data_parallel_size
 
     # the layout order is: DP x PP x TP
     # to get group_ranks for each dimension, transpose that dimension to the
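The detection in the hunk above is plain arithmetic: if the world size reported by torch.distributed is larger than what the engine's own parallel config accounts for, the extra factor is attributed to an external launcher (for example, one engine per DP replica inside an RLHF trainer). Below is a minimal sketch of that arithmetic; it is not vLLM API, and `infer_data_parallel_size` and its parameters are hypothetical names standing in for values vLLM reads from its config.

```python
# Minimal sketch (not vLLM API) of the external-DP detection arithmetic.
def infer_data_parallel_size(actual_world_size: int,
                             configured_world_size: int,
                             tp_size: int,
                             pp_size: int,
                             configured_dp_size: int) -> tuple[int, bool]:
    """Return (data_parallel_size, has_external_dp)."""
    if configured_world_size != actual_world_size:
        # More ranks were launched than the engine's own config expects,
        # e.g. a trainer started 8 ranks for an engine configured with
        # TP=2, PP=1; the leftover factor is treated as external DP.
        assert actual_world_size % (tp_size * pp_size) == 0
        return actual_world_size // (tp_size * pp_size), True
    return configured_dp_size, False


print(infer_data_parallel_size(8, 2, 2, 1, 1))  # (4, True): external DP detected
print(infer_data_parallel_size(2, 2, 2, 1, 1))  # (1, False): sizes match, use config value
```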
@@ -940,6 +953,12 @@ def initialize_model_parallel(
                                       2).reshape(-1,
                                                  data_parallel_size).unbind(0)
     group_ranks = [x.tolist() for x in group_ranks]
+    if has_external_dp:
+        # create a dummy dp group that is not actually used,
+        # since this dp is external.
+        # a dummy dp group means every rank is a group by itself.
+        # this way, no communication is needed, no memory is wasted.
+        group_ranks = [[x] for x in range(world_size)]
     _DP = init_model_parallel_group(group_ranks,
                                     get_world_group().local_rank,
                                     backend,
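The dummy group in the second hunk simply makes every rank its own DP group. A tiny illustration of the resulting layout (plain Python, not tied to vLLM; the world size is a made-up example):

```python
# Illustration only: the dummy DP layout used when DP is handled externally.
# Every rank forms a singleton group, so any DP "collective" involves only
# the rank itself: no cross-rank communication and no wasted memory.
world_size = 8  # hypothetical, e.g. 4 external DP replicas x TP=2
dummy_dp_group_ranks = [[rank] for rank in range(world_size)]
print(dummy_dp_group_ranks)
# [[0], [1], [2], [3], [4], [5], [6], [7]]
```

Presumably the point of still creating `_DP` (rather than leaving it unset) is that downstream code can keep assuming the DP group always exists, while the singleton groups guarantee it never carries real traffic.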