
Commit e64afa4

multi-node offline DP+EP example (#15484)
Signed-off-by: youkaichao <youkaichao@gmail.com>
1 parent 1711b92 commit e64afa4

File tree

1 file changed: +97 -23 lines changed


examples/offline_inference/data_parallel.py

Lines changed: 97 additions & 23 deletions
@@ -1,79 +1,153 @@
 # SPDX-License-Identifier: Apache-2.0
-# usage:
-# VLLM_USE_V1=1 python examples/offline_inference/data_parallel.py
-# we need to have a launcher to create multiple data parallel
-# ranks. And each rank will create a vLLM instance to process its own prompts.
+"""
+Usage:
+Single node:
+    python examples/offline_inference/data_parallel.py \
+            --model="ibm-research/PowerMoE-3b" \
+            --dp-size=2 \
+            --tp-size=2
+
+Multi-node:
+    Node 0 (assume the node has ip of 10.99.48.128):
+            python examples/offline_inference/data_parallel.py \
+                    --model="ibm-research/PowerMoE-3b" \
+                    --dp-size=2 \
+                    --tp-size=2 \
+                    --node-size=2 \
+                    --node-rank=0 \
+                    --master-addr=10.99.48.128 \
+                    --master-port=13345
+    Node 1:
+            python examples/offline_inference/data_parallel.py \
+                    --model="ibm-research/PowerMoE-3b" \
+                    --dp-size=2 \
+                    --tp-size=2 \
+                    --node-size=2 \
+                    --node-rank=1 \
+                    --master-addr=10.99.48.128 \
+                    --master-port=13345
+"""
 import os
 
 from vllm import LLM, SamplingParams
 from vllm.utils import get_open_port
 
-GPUs_per_dp_rank = 2
-DP_size = 2
-
 
-def main(dp_size, dp_rank, dp_master_ip, dp_master_port, GPUs_per_dp_rank):
-    os.environ["VLLM_DP_RANK"] = str(dp_rank)
+def main(model, dp_size, local_dp_rank, global_dp_rank, dp_master_ip,
+         dp_master_port, GPUs_per_dp_rank):
+    os.environ["VLLM_DP_RANK"] = str(global_dp_rank)
     os.environ["VLLM_DP_SIZE"] = str(dp_size)
     os.environ["VLLM_DP_MASTER_IP"] = dp_master_ip
     os.environ["VLLM_DP_MASTER_PORT"] = str(dp_master_port)
     # set devices for each dp_rank
     os.environ["CUDA_VISIBLE_DEVICES"] = ",".join(
-        str(i) for i in range(dp_rank * GPUs_per_dp_rank, (dp_rank + 1) *
-                              GPUs_per_dp_rank))
+        str(i)
+        for i in range(local_dp_rank * GPUs_per_dp_rank, (local_dp_rank + 1) *
+                       GPUs_per_dp_rank))
 
     # Sample prompts.
     prompts = [
         "Hello, my name is",
         "The president of the United States is",
         "The capital of France is",
         "The future of AI is",
-    ]
+    ] * 100
 
     # with DP, each rank should process different prompts.
     # usually all the DP ranks process a full dataset,
     # and each rank processes a different part of the dataset.
     promts_per_rank = len(prompts) // dp_size
-    start = dp_rank * promts_per_rank
+    start = global_dp_rank * promts_per_rank
     end = start + promts_per_rank
     prompts = prompts[start:end]
     if len(prompts) == 0:
         # if any rank has no prompts to process,
         # we need to set a placeholder prompt
         prompts = ["Placeholder"]
-    print(f"DP rank {dp_rank} needs to process {len(prompts)} prompts")
+    print(f"DP rank {global_dp_rank} needs to process {len(prompts)} prompts")
 
     # Create a sampling params object.
     # since we are doing data parallel, every rank can have different
     # sampling params. here we set different max_tokens for different
     # ranks for demonstration.
     sampling_params = SamplingParams(temperature=0.8,
                                      top_p=0.95,
-                                     max_tokens=16 * (dp_rank + 1))
+                                     max_tokens=[16, 20][global_dp_rank % 2])
 
     # Create an LLM.
-    llm = LLM(model="ibm-research/PowerMoE-3b",
+    llm = LLM(model=model,
               tensor_parallel_size=GPUs_per_dp_rank,
               enforce_eager=True,
               enable_expert_parallel=True)
     outputs = llm.generate(prompts, sampling_params)
     # Print the outputs.
-    for output in outputs:
+    for i, output in enumerate(outputs):
+        if i >= 5:
+            # print only 5 outputs
+            break
         prompt = output.prompt
         generated_text = output.outputs[0].text
-        print(f"DP rank {dp_rank}, Prompt: {prompt!r}, "
+        print(f"DP rank {global_dp_rank}, Prompt: {prompt!r}, "
               f"Generated text: {generated_text!r}")
 
 
 if __name__ == "__main__":
+    import argparse
+    parser = argparse.ArgumentParser(description="Data Parallel Inference")
+    parser.add_argument("--model",
+                        type=str,
+                        default="ibm-research/PowerMoE-3b",
+                        help="Model name or path")
+    parser.add_argument("--dp-size",
+                        type=int,
+                        default=2,
+                        help="Data parallel size")
+    parser.add_argument("--tp-size",
+                        type=int,
+                        default=2,
+                        help="Tensor parallel size")
+    parser.add_argument("--node-size",
+                        type=int,
+                        default=1,
+                        help="Total number of nodes")
+    parser.add_argument("--node-rank",
+                        type=int,
+                        default=0,
+                        help="Rank of the current node")
+    parser.add_argument("--master-addr",
+                        type=str,
+                        default="",
+                        help="Master node IP address")
+    parser.add_argument("--master-port",
+                        type=int,
+                        default=0,
+                        help="Master node port")
+    args = parser.parse_args()
+
+    dp_size = args.dp_size
+    tp_size = args.tp_size
+    node_size = args.node_size
+    node_rank = args.node_rank
+
+    if node_size == 1:
+        dp_master_ip = "127.0.0.1"
+        dp_master_port = get_open_port()
+    else:
+        dp_master_ip = args.master_addr
+        dp_master_port = args.master_port
+
+    assert dp_size % node_size == 0, "dp_size should be divisible by node_size"
+    dp_per_node = dp_size // node_size
+
     from multiprocessing import Process
-    dp_master_ip = "127.0.0.1"
-    dp_master_port = get_open_port()
+
     procs = []
-    for i in range(DP_size):
+    for local_dp_rank, global_dp_rank in enumerate(
+            range(node_rank * dp_per_node, (node_rank + 1) * dp_per_node)):
         proc = Process(target=main,
-                       args=(DP_size, i, dp_master_ip, dp_master_port,
-                             GPUs_per_dp_rank))
+                       args=(args.model, dp_size, local_dp_rank,
+                             global_dp_rank, dp_master_ip, dp_master_port,
+                             tp_size))
         proc.start()
         procs.append(proc)
     exit_code = 0
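To make the launcher's rank bookkeeping concrete, the standalone sketch below reuses the arithmetic from the diff (the enumerate(range(...)) loop and the CUDA_VISIBLE_DEVICES slice) and prints what each node runs. The dp_size=2, tp_size=2, node_size=2 values are taken from the docstring's multi-node example; the printed lines are computed here, not captured from the commit.

    # Rank/device bookkeeping from the example, for dp_size=2, tp_size=2,
    # node_size=2 (the docstring's multi-node setup).
    dp_size, tp_size, node_size = 2, 2, 2
    dp_per_node = dp_size // node_size  # each node hosts one DP rank

    for node_rank in range(node_size):
        for local_dp_rank, global_dp_rank in enumerate(
                range(node_rank * dp_per_node, (node_rank + 1) * dp_per_node)):
            # same slice main() feeds into CUDA_VISIBLE_DEVICES
            gpus = list(
                range(local_dp_rank * tp_size, (local_dp_rank + 1) * tp_size))
            print(f"node {node_rank}: global DP rank {global_dp_rank}, "
                  f"local DP rank {local_dp_rank}, GPUs {gpus}")

    # node 0: global DP rank 0, local DP rank 0, GPUs [0, 1]
    # node 1: global DP rank 1, local DP rank 0, GPUs [0, 1]

Both nodes pin their DP rank to local GPUs 0 and 1 because the device slice uses the local rank, while VLLM_DP_RANK uses the global rank; that is why main() now carries both indices.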
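The prompt split inside main() is a contiguous shard per global DP rank, and the [16, 20] indexing only demonstrates that each rank may use its own SamplingParams. A small sketch of both, assuming the example's 4 * 100 = 400 prompts and dp_size=2:

    # Shard 400 prompts across 2 DP ranks, as main() does.
    prompts = ["Hello, my name is"] * 400  # stand-in for the example's list
    dp_size = 2
    prompts_per_rank = len(prompts) // dp_size  # 200
    for global_dp_rank in range(dp_size):
        start = global_dp_rank * prompts_per_rank
        end = start + prompts_per_rank
        max_tokens = [16, 20][global_dp_rank % 2]  # even ranks 16, odd 20
        print(f"rank {global_dp_rank}: prompts[{start}:{end}], "
              f"max_tokens={max_tokens}")

    # rank 0: prompts[0:200], max_tokens=16
    # rank 1: prompts[200:400], max_tokens=20

Note the floor division: a remainder of prompts would be silently dropped, and a rank whose slice comes out empty falls back to the "Placeholder" prompt so that every rank still runs generation.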
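The hunk stops at exit_code = 0, so the unchanged remainder of the file is not shown on this page. For readers adapting the launcher, here is a minimal sketch of how such a driver is commonly finished, assuming a join with timeout plus exit-code propagation (the worker below is a hypothetical stand-in for main(), not the file's actual tail):

    from multiprocessing import Process
    import time

    def worker():
        time.sleep(1)  # hypothetical stand-in for main(...)

    if __name__ == "__main__":
        procs = [Process(target=worker) for _ in range(2)]
        for proc in procs:
            proc.start()

        exit_code = 0
        for proc in procs:
            proc.join(timeout=300)
            if proc.exitcode is None:
                proc.kill()  # worker is still running after the timeout
                exit_code = 1
            elif proc.exitcode:
                exit_code = proc.exitcode  # propagate a worker failure
        raise SystemExit(exit_code)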
