Skip to content

Commit 44ceca5

Browse files
committed
polish code
1 parent 0c82ec6 commit 44ceca5

File tree

4 files changed

+29
-50
lines changed

4 files changed

+29
-50
lines changed

examples/qwen/conf/config_qwen2.5_7b_pd_disaggregation.yaml

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -11,10 +11,10 @@ experiment:
1111
port: 10001
1212
use_fs_serve: false
1313
prefill_decode_disaggregation: true
14-
prefill_num: 1
15-
prefill_address: 10.1.1.122 # optional, default "auto"
14+
prefill_num: 2
15+
prefill_address: x.x.x.x # optional, default "auto"
1616
decode_num: 2
17-
decode_address: 10.1.1.108 # optional, default "auto"
17+
decode_address: x.x.x.x # optional, default "auto"
1818
runner:
1919
hostfile: examples/qwen/conf/hostfile.txt
2020
docker: fr-v2

examples/qwen/conf/hostfile.txt

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
# ip slots type=xxx[optional]
22
# master node
3-
10.1.1.122 slots=8 type=gpu
3+
x.x.x.x slots=8 type=gpu
44
# worker nodes
5-
10.1.1.108 slots=8 type=gpu
5+
x.x.x.x slots=8 type=gpu

flagscale/runner/runner_serve.py

Lines changed: 8 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -294,7 +294,6 @@ def _generate_run_script_serve(config, host, node_rank, cmd, background=True, wi
294294
vllm_path = f"{root_dir}/vllm"
295295
deploy_config = config.experiment.get("deploy", {})
296296
envs = config.experiment.get("envs", {})
297-
print(f"shell file ======================== {host_run_script_file}", flush=True)
298297
with open(host_run_script_file, "w") as f:
299298
f.write("#!/bin/bash\n\n")
300299
f.write("set -x\n")
@@ -321,17 +320,14 @@ def _generate_run_script_serve(config, host, node_rank, cmd, background=True, wi
321320
kv_related_ports = _get_multiple_free_ports(ports_num)
322321
pd_proxy_port = deploy_config.get("pd_proxy_port", None)
323322
if not pd_proxy_port:
324-
raise ValueError(
325-
f"PD disaggregation requires a proxy port to be set."
326-
)
323+
raise ValueError(f"PD disaggregation requires a proxy port to be set.")
327324

328325
engine_args = _get_engine_args(config)
329326
command_items = ["vllm", "serve"]
330327
command_items.append(engine_args["model"])
331328
other_args = flatten_dict_to_args(engine_args, ["model", "port"])
332329
command_items.extend(other_args)
333330
vllm_command = " ".join(command_items)
334-
# vllm_command = "nohup " + vllm_command
335331
if before_start_cmd:
336332
vllm_command = f"{before_start_cmd} && " + vllm_command
337333
if envs_str:
@@ -386,21 +382,18 @@ def _generate_run_script_serve(config, host, node_rank, cmd, background=True, wi
386382
"http_port": str(http_port),
387383
},
388384
}
389-
print(
385+
logger.info(
390386
f"============= prefill instance {i}, p_kv_config: {p_kv_config} =============",
391387
flush=True,
392388
)
393389
card_ids = resource_manager.get_available_card_ids(
394-
address=p_address,
395-
num=each_instance_card_num,
390+
address=p_address, num=each_instance_card_num
396391
)
397392
card_ids_str = ",".join(map(str, card_ids))
398393
ids_env = f"export CUDA_VISIBLE_DEVICES={card_ids_str}"
399394

400395
p_kv_config_json = json.dumps(p_kv_config)
401-
p_instance_log_path = os.path.join(
402-
default_log_dir, f"prefill_{i}.log"
403-
)
396+
p_instance_log_path = os.path.join(default_log_dir, f"prefill_{i}.log")
404397

405398
if p_address != master_ip:
406399
p_kv_config_formate_json = p_kv_config_json.replace('"', '\\"')
@@ -433,21 +426,18 @@ def _generate_run_script_serve(config, host, node_rank, cmd, background=True, wi
433426
"http_port": str(http_port),
434427
},
435428
}
436-
print(
429+
logger.info(
437430
f"============= decode instance {i}, d_kv_config: {d_kv_config} =============",
438431
flush=True,
439432
)
440433
card_ids = resource_manager.get_available_card_ids(
441-
address=d_address,
442-
num=each_instance_card_num,
434+
address=d_address, num=each_instance_card_num
443435
)
444436
card_ids_str = ",".join(map(str, card_ids))
445437
ids_env = f"export CUDA_VISIBLE_DEVICES={card_ids_str}"
446438

447439
d_kv_config_json = json.dumps(d_kv_config)
448-
d_instance_log_path = os.path.join(
449-
default_log_dir, f"decode_{j}.log"
450-
)
440+
d_instance_log_path = os.path.join(default_log_dir, f"decode_{j}.log")
451441

452442
if d_address != master_ip:
453443
d_kv_config_formate_json = d_kv_config_json.replace('"', '\\"')
@@ -683,9 +673,7 @@ def _prepare(self):
683673
self.user_envs = self.config.experiment.get("envs", {})
684674
entrypoint = self.config.experiment.task.get("entrypoint", None)
685675
if self.inference_engine: # pd_disagg_router
686-
if self.config.experiment.get("deploy", {}).get(
687-
"prefill_decode_disaggregation", False
688-
):
676+
if self.config.experiment.get("deploy", {}).get("prefill_decode_disaggregation", False):
689677
self.user_script = "flagscale/serve/run_pd_disagg_router.py"
690678
elif not self.use_fs_serve:
691679
self.user_script = "flagscale/serve/run_inference_engine.py"
@@ -783,7 +771,6 @@ def _stop_each(self, host, node_rank):
783771
kill_process_tree(pid)
784772

785773
ray_executable = shutil.which("ray")
786-
print(ray_executable)
787774
if ray_executable:
788775
ray_path = os.path.realpath(ray_executable)
789776
os.system(f"{ray_path} stop")

flagscale/serve/run_pd_disagg_router.py

Lines changed: 16 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,11 @@
1-
import logging
1+
# Copyright (c) 2025, BAAI. All rights reserved.
2+
#
3+
# Adopted from https://github.com/vllm-project/vllm/blob/1ad957950ffc1552af5abda78c03d88ddb67945b/examples/online_serving/disagg_xpyd/disagg_prefill_proxy_xpyd.py. Below is the original copyright:
4+
#
5+
# SPDX-License-Identifier: Apache-2.0
6+
#
7+
8+
29
import os
310
import random
411
import socket
@@ -8,6 +15,7 @@
815
import aiohttp
916
import msgpack
1017
import zmq
18+
1119
from quart import Quart, make_response, request
1220

1321
try:
@@ -29,18 +37,13 @@ class LoadManager:
2937
def __init__(self):
3038
self._lock = threading.Lock()
3139
# Each resource type 'P' or 'D' maps to {http_addr: {'zmq': zmq_addr, 'load': int}}
32-
self._instances: dict[str, dict[str, dict[str, object]]] = {
33-
"P": {},
34-
"D": {},
35-
}
40+
self._instances: dict[str, dict[str, dict[str, object]]] = {"P": {}, "D": {}}
3641

3742
def register(self, rtype: str, http_addr: str, zmq_addr: str):
3843
with self._lock:
3944
if http_addr not in self._instances[rtype]:
4045
self._instances[rtype][http_addr] = {"zmq": zmq_addr, "load": 0}
41-
logger.info(
42-
f"Registered new {rtype}-instance {http_addr} (zmq={zmq_addr})"
43-
)
46+
logger.info(f"Registered new {rtype}-instance {http_addr} (zmq={zmq_addr})")
4447
else:
4548
# If zmq address changed, synchronize it
4649
self._instances[rtype][http_addr]["zmq"] = zmq_addr
@@ -67,13 +70,8 @@ def get_random(self, rtype: str) -> tuple[str, str]:
6770

6871
def get_robin_loaded(self, rtype: str) -> tuple[str, str]:
6972
with self._lock:
70-
http_addr, info = min(
71-
self._instances[rtype].items(), key=lambda kv: kv[1]["load"]
72-
)
73-
print(
74-
f"========== whole instance status {self._instances}==========",
75-
flush=True,
76-
)
73+
http_addr, info = min(self._instances[rtype].items(), key=lambda kv: kv[1]["load"])
74+
print(f"========== whole instance status {self._instances}==========", flush=True)
7775
return http_addr, info["zmq"]
7876

7977

@@ -168,9 +166,7 @@ async def forward_request(url, data, request_id):
168166
async def handle_request():
169167
try:
170168
original_data = await request.get_json()
171-
endpoint = (
172-
request.path
173-
) # this will be '/v1/completions' or '/v1/chat/completions'
169+
endpoint = request.path # this will be '/v1/completions' or '/v1/chat/completions'
174170

175171
# Prefill request: max_tokens=1
176172
prefill_request = original_data.copy()
@@ -191,9 +187,7 @@ async def handle_request():
191187
logger.info(f"Selected D-instance {decode_addr} via '{SCHEDULING_STRATEGY}'")
192188

193189
# Keep original request_id composition format
194-
request_id = (
195-
f"___prefill_addr_{prefill_zmq}___decode_addr_{decode_zmq}_{random_uuid()}"
196-
)
190+
request_id = f"___prefill_addr_{prefill_zmq}___decode_addr_{decode_zmq}_{random_uuid()}"
197191

198192
# Execute Prefill and update load
199193
lm.increment_load("P", prefill_addr)
@@ -235,9 +229,7 @@ def main():
235229
raise ValueError("No port specified in deploy config")
236230
if not pd_proxy_port:
237231
raise ValueError("No pd_proxy_port specified in deploy config")
238-
print(
239-
f"Starting Proxy Server...with pd_proxy_port {pd_proxy_port} and serve_port {serve_port}"
240-
)
232+
print(f"Starting Proxy Server...with pd_proxy_port {pd_proxy_port} and serve_port {serve_port}")
241233
listener = start_service_discovery("0.0.0.0", pd_proxy_port)
242234
app.run(host="0.0.0.0", port=serve_port)
243235
listener.join()

0 commit comments

Comments
 (0)