|
1 | 1 | # SPDX-License-Identifier: Apache-2.0 |
2 | | -# usage: |
3 | | -# VLLM_USE_V1=1 python examples/offline_inference/data_parallel.py |
4 | | -# we need to have a launcher to create multiple data parallel |
5 | | -# ranks. And each rank will create a vLLM instance to process its own prompts. |
| 2 | +""" |
| 3 | +Usage: |
| 4 | +Single node: |
| 5 | + python examples/offline_inference/data_parallel.py \ |
| 6 | + --model="ibm-research/PowerMoE-3b" \ |
| 7 | + --dp-size=2 \ |
| 8 | + --tp-size=2 |
| 9 | +
|
| 10 | +Multi-node: |
| 11 | + Node 0 (assume the node has ip of 10.99.48.128): |
| 12 | + python examples/offline_inference/data_parallel.py \ |
| 13 | + --model="ibm-research/PowerMoE-3b" \ |
| 14 | + --dp-size=2 \ |
| 15 | + --tp-size=2 \ |
| 16 | + --node-size=2 \ |
| 17 | + --node-rank=0 \ |
| 18 | + --master-addr=10.99.48.128 \ |
| 19 | + --master-port=13345 |
| 20 | + Node 1: |
| 21 | + python examples/offline_inference/data_parallel.py \ |
| 22 | + --model="ibm-research/PowerMoE-3b" \ |
| 23 | + --dp-size=2 \ |
| 24 | + --tp-size=2 \ |
| 25 | + --node-size=2 \ |
| 26 | + --node-rank=1 \ |
| 27 | + --master-addr=10.99.48.128 \ |
| 28 | + --master-port=13345 |
| 29 | +""" |
6 | 30 | import os |
7 | 31 |
|
8 | 32 | from vllm import LLM, SamplingParams |
9 | 33 | from vllm.utils import get_open_port |
10 | 34 |
|
11 | | -GPUs_per_dp_rank = 2 |
12 | | -DP_size = 2 |
13 | | - |
14 | 35 |
|
15 | | -def main(dp_size, dp_rank, dp_master_ip, dp_master_port, GPUs_per_dp_rank): |
16 | | - os.environ["VLLM_DP_RANK"] = str(dp_rank) |
| 36 | +def main(model, dp_size, local_dp_rank, global_dp_rank, dp_master_ip, |
| 37 | + dp_master_port, GPUs_per_dp_rank): |
| 38 | + os.environ["VLLM_DP_RANK"] = str(global_dp_rank) |
17 | 39 | os.environ["VLLM_DP_SIZE"] = str(dp_size) |
18 | 40 | os.environ["VLLM_DP_MASTER_IP"] = dp_master_ip |
19 | 41 | os.environ["VLLM_DP_MASTER_PORT"] = str(dp_master_port) |
20 | 42 | # set devices for each dp_rank |
21 | 43 | os.environ["CUDA_VISIBLE_DEVICES"] = ",".join( |
22 | | - str(i) for i in range(dp_rank * GPUs_per_dp_rank, (dp_rank + 1) * |
23 | | - GPUs_per_dp_rank)) |
| 44 | + str(i) |
| 45 | + for i in range(local_dp_rank * GPUs_per_dp_rank, (local_dp_rank + 1) * |
| 46 | + GPUs_per_dp_rank)) |
24 | 47 |
|
25 | 48 | # Sample prompts. |
26 | 49 | prompts = [ |
27 | 50 | "Hello, my name is", |
28 | 51 | "The president of the United States is", |
29 | 52 | "The capital of France is", |
30 | 53 | "The future of AI is", |
31 | | - ] |
| 54 | + ] * 100 |
32 | 55 |
|
33 | 56 | # with DP, each rank should process different prompts. |
34 | 57 | # usually all the DP ranks process a full dataset, |
35 | 58 | # and each rank processes a different part of the dataset. |
36 | 59 | promts_per_rank = len(prompts) // dp_size |
37 | | - start = dp_rank * promts_per_rank |
| 60 | + start = global_dp_rank * promts_per_rank |
38 | 61 | end = start + promts_per_rank |
39 | 62 | prompts = prompts[start:end] |
40 | 63 | if len(prompts) == 0: |
41 | 64 | # if any rank has no prompts to process, |
42 | 65 | # we need to set a placeholder prompt |
43 | 66 | prompts = ["Placeholder"] |
44 | | - print(f"DP rank {dp_rank} needs to process {len(prompts)} prompts") |
| 67 | + print(f"DP rank {global_dp_rank} needs to process {len(prompts)} prompts") |
45 | 68 |
|
46 | 69 | # Create a sampling params object. |
47 | 70 | # since we are doing data parallel, every rank can have different |
48 | 71 | # sampling params. here we set different max_tokens for different |
49 | 72 | # ranks for demonstration. |
50 | 73 | sampling_params = SamplingParams(temperature=0.8, |
51 | 74 | top_p=0.95, |
52 | | - max_tokens=16 * (dp_rank + 1)) |
| 75 | + max_tokens=[16, 20][global_dp_rank % 2]) |
53 | 76 |
|
54 | 77 | # Create an LLM. |
55 | | - llm = LLM(model="ibm-research/PowerMoE-3b", |
| 78 | + llm = LLM(model=model, |
56 | 79 | tensor_parallel_size=GPUs_per_dp_rank, |
57 | 80 | enforce_eager=True, |
58 | 81 | enable_expert_parallel=True) |
59 | 82 | outputs = llm.generate(prompts, sampling_params) |
60 | 83 | # Print the outputs. |
61 | | - for output in outputs: |
| 84 | + for i, output in enumerate(outputs): |
| 85 | + if i >= 5: |
| 86 | + # print only 5 outputs |
| 87 | + break |
62 | 88 | prompt = output.prompt |
63 | 89 | generated_text = output.outputs[0].text |
64 | | - print(f"DP rank {dp_rank}, Prompt: {prompt!r}, " |
| 90 | + print(f"DP rank {global_dp_rank}, Prompt: {prompt!r}, " |
65 | 91 | f"Generated text: {generated_text!r}") |
66 | 92 |
|
67 | 93 |
|
68 | 94 | if __name__ == "__main__": |
| 95 | + import argparse |
| 96 | + parser = argparse.ArgumentParser(description="Data Parallel Inference") |
| 97 | + parser.add_argument("--model", |
| 98 | + type=str, |
| 99 | + default="ibm-research/PowerMoE-3b", |
| 100 | + help="Model name or path") |
| 101 | + parser.add_argument("--dp-size", |
| 102 | + type=int, |
| 103 | + default=2, |
| 104 | + help="Data parallel size") |
| 105 | + parser.add_argument("--tp-size", |
| 106 | + type=int, |
| 107 | + default=2, |
| 108 | + help="Tensor parallel size") |
| 109 | + parser.add_argument("--node-size", |
| 110 | + type=int, |
| 111 | + default=1, |
| 112 | + help="Total number of nodes") |
| 113 | + parser.add_argument("--node-rank", |
| 114 | + type=int, |
| 115 | + default=0, |
| 116 | + help="Rank of the current node") |
| 117 | + parser.add_argument("--master-addr", |
| 118 | + type=str, |
| 119 | + default="", |
| 120 | + help="Master node IP address") |
| 121 | + parser.add_argument("--master-port", |
| 122 | + type=int, |
| 123 | + default=0, |
| 124 | + help="Master node port") |
| 125 | + args = parser.parse_args() |
| 126 | + |
| 127 | + dp_size = args.dp_size |
| 128 | + tp_size = args.tp_size |
| 129 | + node_size = args.node_size |
| 130 | + node_rank = args.node_rank |
| 131 | + |
| 132 | + if node_size == 1: |
| 133 | + dp_master_ip = "127.0.0.1" |
| 134 | + dp_master_port = get_open_port() |
| 135 | + else: |
| 136 | + dp_master_ip = args.master_addr |
| 137 | + dp_master_port = args.master_port |
| 138 | + |
| 139 | + assert dp_size % node_size == 0, "dp_size should be divisible by node_size" |
| 140 | + dp_per_node = dp_size // node_size |
| 141 | + |
69 | 142 | from multiprocessing import Process |
70 | | - dp_master_ip = "127.0.0.1" |
71 | | - dp_master_port = get_open_port() |
| 143 | + |
72 | 144 | procs = [] |
73 | | - for i in range(DP_size): |
| 145 | + for local_dp_rank, global_dp_rank in enumerate( |
| 146 | + range(node_rank * dp_per_node, (node_rank + 1) * dp_per_node)): |
74 | 147 | proc = Process(target=main, |
75 | | - args=(DP_size, i, dp_master_ip, dp_master_port, |
76 | | - GPUs_per_dp_rank)) |
| 148 | + args=(args.model, dp_size, local_dp_rank, |
| 149 | + global_dp_rank, dp_master_ip, dp_master_port, |
| 150 | + tp_size)) |
77 | 151 | proc.start() |
78 | 152 | procs.append(proc) |
79 | 153 | exit_code = 0 |
|
0 commit comments