|
39 | 39 | from vllm.transformers_utils.utils import check_gguf_file |
40 | 40 | from vllm.usage.usage_lib import UsageContext |
41 | 41 | from vllm.utils import (STR_DUAL_CHUNK_FLASH_ATTN_VAL, FlexibleArgumentParser, |
42 | | - GiB_bytes, is_in_ray_actor) |
| 42 | + GiB_bytes, get_ip, is_in_ray_actor) |
43 | 43 |
|
44 | 44 | # yapf: enable |
45 | 45 |
|
@@ -292,6 +292,7 @@ class EngineArgs: |
292 | 292 | data_parallel_size_local: Optional[int] = None |
293 | 293 | data_parallel_address: Optional[str] = None |
294 | 294 | data_parallel_rpc_port: Optional[int] = None |
| 295 | + data_parallel_backend: str = ParallelConfig.data_parallel_backend |
295 | 296 | enable_expert_parallel: bool = ParallelConfig.enable_expert_parallel |
296 | 297 | max_parallel_loading_workers: Optional[ |
297 | 298 | int] = ParallelConfig.max_parallel_loading_workers |
@@ -624,6 +625,12 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: |
624 | 625 | type=int, |
625 | 626 | help='Port for data parallel RPC ' |
626 | 627 | 'communication.') |
| 628 | + parallel_group.add_argument('--data-parallel-backend', |
| 629 | + '-dpb', |
| 630 | + type=str, |
| 631 | + default='mp', |
| 632 | + help='Backend to use for data parallelism, ' |
| 633 | + 'either "mp" or "ray".') |
627 | 634 | parallel_group.add_argument( |
628 | 635 | "--enable-expert-parallel", |
629 | 636 | **parallel_kwargs["enable_expert_parallel"]) |
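
A minimal sketch of how the new flag can be exercised through the `EngineArgs` CLI helpers touched in this file; the parser wiring below mirrors what vLLM's entrypoints already do, and only `--data-parallel-backend` comes from this diff.

```python
# Hedged sketch: parse the new flag via EngineArgs' CLI wiring.
from vllm.engine.arg_utils import EngineArgs
from vllm.utils import FlexibleArgumentParser

parser = FlexibleArgumentParser()
parser = EngineArgs.add_cli_args(parser)
args = parser.parse_args([
    "--model", "facebook/opt-125m",
    "--data-parallel-size", "2",
    "--data-parallel-backend", "ray",  # default is "mp"
])
engine_args = EngineArgs.from_cli_args(args)
print(engine_args.data_parallel_backend)  # -> "ray"
```
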
@@ -1059,23 +1066,37 @@ def create_engine_config( |
1059 | 1066 |
|
1060 | 1067 | # DP address, used in multi-node case for torch distributed group |
1061 | 1068 | # and ZMQ sockets. |
1062 | | - data_parallel_address = self.data_parallel_address if ( |
1063 | | - self.data_parallel_address |
1064 | | - is not None) else ParallelConfig.data_parallel_master_ip |
| 1069 | + if self.data_parallel_address is None: |
| 1070 | + if self.data_parallel_backend == "ray": |
| 1071 | + host_ip = get_ip() |
| 1072 | + logger.info( |
| 1073 | + "Using host IP %s as ray-based data parallel address", |
| 1074 | + host_ip) |
| 1075 | + data_parallel_address = host_ip |
| 1076 | + else: |
| 1077 | + assert self.data_parallel_backend == "mp", ( |
| 1078 | + "data_parallel_backend can only be ray or mp, got " |
| 1079 | + f"{self.data_parallel_backend}") |
| 1080 | + data_parallel_address = ParallelConfig.data_parallel_master_ip |
| 1081 | + else: |
| 1082 | + data_parallel_address = self.data_parallel_address |
1065 | 1083 |
|
1066 | 1084 | # This port is only used when there are remote data parallel engines, |
1067 | 1085 | # otherwise the local IPC transport is used. |
1068 | 1086 | data_parallel_rpc_port = self.data_parallel_rpc_port if ( |
1069 | 1087 | self.data_parallel_rpc_port |
1070 | 1088 | is not None) else ParallelConfig.data_parallel_rpc_port |
1071 | 1089 |
|
| 1090 | + data_parallel_backend = self.data_parallel_backend |
| 1091 | + |
1072 | 1092 | parallel_config = ParallelConfig( |
1073 | 1093 | pipeline_parallel_size=self.pipeline_parallel_size, |
1074 | 1094 | tensor_parallel_size=self.tensor_parallel_size, |
1075 | 1095 | data_parallel_size=self.data_parallel_size, |
1076 | 1096 | data_parallel_size_local=data_parallel_size_local, |
1077 | 1097 | data_parallel_master_ip=data_parallel_address, |
1078 | 1098 | data_parallel_rpc_port=data_parallel_rpc_port, |
| 1099 | + data_parallel_backend=data_parallel_backend, |
1079 | 1100 | enable_expert_parallel=self.enable_expert_parallel, |
1080 | 1101 | max_parallel_loading_workers=self.max_parallel_loading_workers, |
1081 | 1102 | disable_custom_all_reduce=self.disable_custom_all_reduce, |
|
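
A hedged sketch of the address-resolution fallback added in `create_engine_config()`: with the `ray` backend and no explicit `--data-parallel-address`, the host IP returned by `get_ip()` becomes `data_parallel_master_ip`; with the default `mp` backend, the `ParallelConfig` default is kept. The model name is only a placeholder, and `create_engine_config()` still resolves the model config as usual.

```python
# Illustrative only; create_engine_config() also builds the model config,
# so this assumes the placeholder model is resolvable.
from vllm.engine.arg_utils import EngineArgs

# Ray backend, no explicit address: falls back to the host IP from get_ip().
ray_args = EngineArgs(
    model="facebook/opt-125m",
    data_parallel_size=2,
    data_parallel_backend="ray",
)
config = ray_args.create_engine_config()
print(config.parallel_config.data_parallel_master_ip)  # host IP, e.g. "10.0.0.5"

# Default "mp" backend: keeps ParallelConfig.data_parallel_master_ip
# unless --data-parallel-address is passed.
mp_args = EngineArgs(model="facebook/opt-125m", data_parallel_size=2)
print(mp_args.create_engine_config()
      .parallel_config.data_parallel_master_ip)
```
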