|
34 | 34 | from vllm.utils import get_open_port |
35 | 35 |
|
36 | 36 |
|
| 37 | +def parse_args(): |
| 38 | + import argparse |
| 39 | + parser = argparse.ArgumentParser(description="Data Parallel Inference") |
| 40 | + parser.add_argument("--model", |
| 41 | + type=str, |
| 42 | + default="ibm-research/PowerMoE-3b", |
| 43 | + help="Model name or path") |
| 44 | + parser.add_argument("--dp-size", |
| 45 | + type=int, |
| 46 | + default=2, |
| 47 | + help="Data parallel size") |
| 48 | + parser.add_argument("--tp-size", |
| 49 | + type=int, |
| 50 | + default=2, |
| 51 | + help="Tensor parallel size") |
| 52 | + parser.add_argument("--node-size", |
| 53 | + type=int, |
| 54 | + default=1, |
| 55 | + help="Total number of nodes") |
| 56 | + parser.add_argument("--node-rank", |
| 57 | + type=int, |
| 58 | + default=0, |
| 59 | + help="Rank of the current node") |
| 60 | + parser.add_argument("--master-addr", |
| 61 | + type=str, |
| 62 | + default="", |
| 63 | + help="Master node IP address") |
| 64 | + parser.add_argument("--master-port", |
| 65 | + type=int, |
| 66 | + default=0, |
| 67 | + help="Master node port") |
| 68 | + return parser.parse_args() |
| 69 | + |
| 70 | + |
37 | 71 | def main(model, dp_size, local_dp_rank, global_dp_rank, dp_master_ip, |
38 | 72 | dp_master_port, GPUs_per_dp_rank): |
39 | 73 | os.environ["VLLM_DP_RANK"] = str(global_dp_rank) |
@@ -95,37 +129,8 @@ def main(model, dp_size, local_dp_rank, global_dp_rank, dp_master_ip, |
95 | 129 |
|
96 | 130 |
|
97 | 131 | if __name__ == "__main__": |
98 | | - import argparse |
99 | | - parser = argparse.ArgumentParser(description="Data Parallel Inference") |
100 | | - parser.add_argument("--model", |
101 | | - type=str, |
102 | | - default="ibm-research/PowerMoE-3b", |
103 | | - help="Model name or path") |
104 | | - parser.add_argument("--dp-size", |
105 | | - type=int, |
106 | | - default=2, |
107 | | - help="Data parallel size") |
108 | | - parser.add_argument("--tp-size", |
109 | | - type=int, |
110 | | - default=2, |
111 | | - help="Tensor parallel size") |
112 | | - parser.add_argument("--node-size", |
113 | | - type=int, |
114 | | - default=1, |
115 | | - help="Total number of nodes") |
116 | | - parser.add_argument("--node-rank", |
117 | | - type=int, |
118 | | - default=0, |
119 | | - help="Rank of the current node") |
120 | | - parser.add_argument("--master-addr", |
121 | | - type=str, |
122 | | - default="", |
123 | | - help="Master node IP address") |
124 | | - parser.add_argument("--master-port", |
125 | | - type=int, |
126 | | - default=0, |
127 | | - help="Master node port") |
128 | | - args = parser.parse_args() |
| 132 | + |
| 133 | + args = parse_args() |
129 | 134 |
|
130 | 135 | dp_size = args.dp_size |
131 | 136 | tp_size = args.tp_size |
|
0 commit comments