From 45ab9f5399f627acec4174e0607cdc55c7644e71 Mon Sep 17 00:00:00 2001 From: chenzhuo Date: Tue, 15 Apr 2025 16:01:24 +0800 Subject: [PATCH 01/62] launch pd disagg --- examples/deepseek/conf/hostfile.txt | 4 +- .../config_qwen2.5_7b_pd_disaggregation.yaml | 26 +++ flagscale/serve/pd_disagg_router.py | 214 ++++++++++++++++++ 3 files changed, 242 insertions(+), 2 deletions(-) create mode 100644 examples/qwen/conf/config_qwen2.5_7b_pd_disaggregation.yaml create mode 100644 flagscale/serve/pd_disagg_router.py diff --git a/examples/deepseek/conf/hostfile.txt b/examples/deepseek/conf/hostfile.txt index 0d8b1e05f..e49bb3049 100644 --- a/examples/deepseek/conf/hostfile.txt +++ b/examples/deepseek/conf/hostfile.txt @@ -1,5 +1,5 @@ # ip slots type=xxx[optional] # master node -x.x.x.x slots=8 type=gpu +10.1.1.122 slots=8 type=gpu # worker nodes -x.x.x.x slots=8 type=gpu +10.1.1.108 slots=8 type=gpu diff --git a/examples/qwen/conf/config_qwen2.5_7b_pd_disaggregation.yaml b/examples/qwen/conf/config_qwen2.5_7b_pd_disaggregation.yaml new file mode 100644 index 000000000..64ef15565 --- /dev/null +++ b/examples/qwen/conf/config_qwen2.5_7b_pd_disaggregation.yaml @@ -0,0 +1,26 @@ +defaults: +- _self_ +- serve: serve_qwen2.5_7b + +experiment: + exp_name: qwen2.5_7b + exp_dir: outputs/${experiment.exp_name} + task: + type: serve + deploy: + use_fs_serve: false + prefill_decode_disaggregation: true + prefill_num: 1 + prefill_address: 127.0.0.1 + decode_num: 2 + decode_address: 127.0.0.1 + runner: + hostfile: examples/deepseek/conf/hostfile.txt + envs: + CUDA_DEVICE_MAX_CONNECTIONS: 1 + +action: run + +hydra: + run: + dir: ${experiment.exp_dir}/hydra diff --git a/flagscale/serve/pd_disagg_router.py b/flagscale/serve/pd_disagg_router.py new file mode 100644 index 000000000..f909edc3c --- /dev/null +++ b/flagscale/serve/pd_disagg_router.py @@ -0,0 +1,214 @@ +import os +import random +import socket +import threading +import uuid + +import aiohttp +import msgpack +import zmq +from quart import Quart, make_response, request + +prefill_instances: dict[str, str] = {} # http_address: zmq_address +decode_instances: dict[str, str] = {} # http_address: zmq_address + +prefill_cv = threading.Condition() +decode_cv = threading.Condition() + + +def _listen_for_register(poller, router_socket): + while True: + socks = dict(poller.poll()) + if router_socket in socks: + remote_address, message = router_socket.recv_multipart() + # data: {"type": "P", "http_address": "ip:port", + # "zmq_address": "ip:port"} + data = msgpack.loads(message) + # print("Received message from %s, data: %s", + # remote_address.decode(), data) + if data["type"] == "P": + global prefill_instances + global prefill_cv + with prefill_cv: + prefill_instances[data["http_address"]] = data["zmq_address"] + elif data["type"] == "D": + global decode_instances + global decode_cv + with decode_cv: + decode_instances[data["http_address"]] = data["zmq_address"] + else: + print( + "Unexpected, Received message from %s, data: %s", + remote_address, + data, + ) + + +def start_service_discovery(hostname, port): + if not hostname: + hostname = socket.gethostname() + if port == 0: + raise ValueError("Port cannot be 0") + + context = zmq.Context() + router_socket = context.socket(zmq.ROUTER) + router_socket.bind(f"tcp://{hostname}:{port}") + + poller = zmq.Poller() + poller.register(router_socket, zmq.POLLIN) + + _listener_thread = threading.Thread( + target=_listen_for_register, args=[poller, router_socket], daemon=True + ) + _listener_thread.start() + return 
_listener_thread + + +AIOHTTP_TIMEOUT = aiohttp.ClientTimeout(total=6 * 60 * 60) + +app = Quart(__name__) + + +def random_uuid() -> str: + return str(uuid.uuid4().hex) + + +async def forward_request(url, data, request_id): + async with aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session: + headers = { + "Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}", + "X-Request-Id": request_id, + } + async with session.post(url=url, json=data, headers=headers) as response: + if response.status == 200: + # if response.headers.get('Transfer-Encoding') == 'chunked': + if True: + async for chunk_bytes in response.content.iter_chunked(1024): + yield chunk_bytes + else: + content = await response.read() + yield content + + +@app.route("/v1/completions", methods=["POST"]) +async def handle_request(): + try: + original_request_data = await request.get_json() + + prefill_request = original_request_data.copy() + # change max_tokens = 1 to let it only do prefill + prefill_request["max_tokens"] = 1 + + global prefill_instances + global prefill_cv + with prefill_cv: + prefill_addr, prefill_zmq_addr = random.choice( + list(prefill_instances.items()) + ) + print( + "handle_request, prefill_addr: %s, zmq_addr: %s", + prefill_addr, + prefill_zmq_addr, + ) + + global decode_instances + global decode_cv + with decode_cv: + decode_addr, decode_zmq_addr = random.choice(list(decode_instances.items())) + print( + "handle_request, decode_addr: %s, zmq_addr: %s", + decode_addr, + decode_zmq_addr, + ) + + request_id = f"___prefill_addr_{prefill_zmq_addr}___decode_addr_{decode_zmq_addr}_{random_uuid()}" + + # finish prefill + async for _ in forward_request( + f"http://{prefill_addr}/v1/completions", prefill_request, request_id + ): + continue + + # return decode + generator = forward_request( + f"http://{decode_addr}/v1/completions", original_request_data, request_id + ) + response = await make_response(generator) + response.timeout = None + + return response + + except Exception as e: + import sys + import traceback + + exc_info = sys.exc_info() + print("Error occurred in disagg prefill proxy server") + print(e) + print("".join(traceback.format_exception(*exc_info))) + + +@app.route("/v1/chat/completions", methods=["POST"]) +async def handle_request(): + try: + original_request_data = await request.get_json() + + prefill_request = original_request_data.copy() + # change max_tokens = 1 to let it only do prefill + prefill_request["max_tokens"] = 1 + + global prefill_instances + global prefill_cv + with prefill_cv: + prefill_addr, prefill_zmq_addr = random.choice( + list(prefill_instances.items()) + ) + print( + "handle_request, prefill_addr: %s, zmq_addr: %s", + prefill_addr, + prefill_zmq_addr, + ) + + global decode_instances + global decode_cv + with decode_cv: + decode_addr, decode_zmq_addr = random.choice(list(decode_instances.items())) + print( + "handle_request, decode_addr: %s, zmq_addr: %s", + decode_addr, + decode_zmq_addr, + ) + + request_id = f"___prefill_addr_{prefill_zmq_addr}___decode_addr_{decode_zmq_addr}_{random_uuid()}" + + # finish prefill + async for _ in forward_request( + f"http://{prefill_addr}/v1/chat/completions", prefill_request, request_id + ): + continue + + # return decode + generator = forward_request( + f"http://{decode_addr}/v1/chat/completions", + original_request_data, + request_id, + ) + response = await make_response(generator) + response.timeout = None + + return response + + except Exception as e: + import sys + import traceback + + exc_info = sys.exc_info() + print("Error 
occurred in disagg prefill proxy server") + print(e) + print("".join(traceback.format_exception(*exc_info))) + + +if __name__ == "__main__": + t = start_service_discovery("0.0.0.0", 30001) + app.run(host="0.0.0.0", port=10001) + t.join() From 7a7076ece1b193f390117aa980278310bd98072a Mon Sep 17 00:00:00 2001 From: chenzhuo Date: Tue, 29 Apr 2025 00:35:16 +0800 Subject: [PATCH 02/62] merge main --- .../config_qwen2.5_7b_pd_disaggregation.yaml | 8 +- examples/qwen/conf/refer.sh | 34 ++ examples/qwen/conf/target.sh | 36 ++ flagscale/runner/runner_serve.py | 358 +++++++++++++++--- 4 files changed, 384 insertions(+), 52 deletions(-) create mode 100644 examples/qwen/conf/refer.sh create mode 100644 examples/qwen/conf/target.sh diff --git a/examples/qwen/conf/config_qwen2.5_7b_pd_disaggregation.yaml b/examples/qwen/conf/config_qwen2.5_7b_pd_disaggregation.yaml index 64ef15565..25b8144b4 100644 --- a/examples/qwen/conf/config_qwen2.5_7b_pd_disaggregation.yaml +++ b/examples/qwen/conf/config_qwen2.5_7b_pd_disaggregation.yaml @@ -8,16 +8,20 @@ experiment: task: type: serve deploy: + port: 10001 use_fs_serve: false prefill_decode_disaggregation: true prefill_num: 1 - prefill_address: 127.0.0.1 + prefill_address: 10.1.1.122 decode_num: 2 - decode_address: 127.0.0.1 + decode_address: 10.1.1.108 runner: hostfile: examples/deepseek/conf/hostfile.txt + docker: fr-v2 envs: CUDA_DEVICE_MAX_CONNECTIONS: 1 + cmds: + before_start: source /root/miniconda3/bin/activate flagscale-inference action: run diff --git a/examples/qwen/conf/refer.sh b/examples/qwen/conf/refer.sh new file mode 100644 index 000000000..f0668435a --- /dev/null +++ b/examples/qwen/conf/refer.sh @@ -0,0 +1,34 @@ +#!/bin/bash + +set -x + +source /root/miniconda3/bin/activate flagscale-inference && export GLOO_SOCKET_IFNAME=bond0 + +if [ -z "$PYTHONPATH" ]; then + export PYTHONPATH=/root/miniconda3/envs/flagscale-inference/lib/python3.12/site-packages:/mine/ip122/tune_qwen/github_flagscale +else + export PYTHONPATH="$PYTHONPATH:/root/miniconda3/envs/flagscale-inference/lib/python3.12/site-packages:/mine/ip122/tune_qwen/github_flagscale" +fi + +ray_path=$(realpath $(which ray)) +# clean nodes +ssh -n -p 22 10.1.1.108 "docker exec ds /bin/bash -c 'source /root/miniconda3/bin/activate flagscale-inference && export GLOO_SOCKET_IFNAME=bond0 && ${ray_path} stop'" +source /root/miniconda3/bin/activate flagscale-inference && export GLOO_SOCKET_IFNAME=bond0 && ${ray_path} stop +pkill -f 'run_inference_engine' +pkill -f 'run_fs_serve_vllm' +pkill -f 'vllm serve' + +# start cluster +# master node +source /root/miniconda3/bin/activate flagscale-inference && export GLOO_SOCKET_IFNAME=bond0 && ${ray_path} start --head --port=59081 --num-gpus=8 + +# worker nodes +ssh -n -p 22 10.1.1.108 "docker exec ds /bin/bash -c 'source /root/miniconda3/bin/activate flagscale-inference && export GLOO_SOCKET_IFNAME=bond0 && ${ray_path} start --address=10.1.1.122:59081 --num-gpus=8'" +mkdir -p /mine/ip122/tune_qwen/github_flagscale/outputs/deepseek_v3/serve_logs +mkdir -p /mine/ip122/tune_qwen/github_flagscale/outputs/deepseek_v3/serve_logs/pids + +cd /mine/ip122/tune_qwen/github_flagscale + +cmd="CUDA_DEVICE_MAX_CONNECTIONS=1 python flagscale/serve/run_inference_engine.py --config-path=/mine/ip122/tune_qwen/github_flagscale/outputs/deepseek_v3/serve_logs/scripts/serve.yaml --log-dir=/mine/ip122/tune_qwen/github_flagscale/outputs/deepseek_v3/serve_logs" + +nohup bash -c "$cmd; sync" >> 
/mine/ip122/tune_qwen/github_flagscale/outputs/deepseek_v3/serve_logs/host_0_localhost.output 2>&1 & echo $! > /mine/ip122/tune_qwen/github_flagscale/outputs/deepseek_v3/serve_logs/pids/host_0_localhost.pid \ No newline at end of file diff --git a/examples/qwen/conf/target.sh b/examples/qwen/conf/target.sh new file mode 100644 index 000000000..abbc609a2 --- /dev/null +++ b/examples/qwen/conf/target.sh @@ -0,0 +1,36 @@ +#!/bin/bash + +set -x + +MODEL_NAME=/models/Qwen2.5-7B-Instruct + +source /root/miniconda3/bin/activate flagscale-inference && export GLOO_SOCKET_IFNAME=bond0 + +if [ -z "$PYTHONPATH" ]; then + export PYTHONPATH=/root/miniconda3/envs/flagscale-inference/lib/python3.12/site-packages:/mine/ip122/tune_qwen/github_flagscale +else + export PYTHONPATH="$PYTHONPATH:/root/miniconda3/envs/flagscale-inference/lib/python3.12/site-packages:/mine/ip122/tune_qwen/github_flagscale" +fi + +ray_path=$(realpath $(which ray)) +# clean nodes +ssh -n -p 22 10.1.1.108 "docker exec ds /bin/bash -c 'source /root/miniconda3/bin/activate flagscale-inference && export GLOO_SOCKET_IFNAME=bond0 && ${ray_path} stop'" +source /root/miniconda3/bin/activate flagscale-inference && export GLOO_SOCKET_IFNAME=bond0 && ${ray_path} stop +pkill -f 'run_inference_engine' +pkill -f 'run_fs_serve_vllm' +pkill -f 'vllm serve' + +# start cluster +# master node +source /root/miniconda3/bin/activate flagscale-inference && export GLOO_SOCKET_IFNAME=bond0 && ${ray_path} start --head --port=59081 --num-gpus=8 + +# worker nodes +ssh -n -p 22 10.1.1.108 "docker exec ds /bin/bash -c 'source /root/miniconda3/bin/activate flagscale-inference && export GLOO_SOCKET_IFNAME=bond0 && ${ray_path} start --address=10.1.1.122:59081 --num-gpus=8'" +mkdir -p /mine/ip122/tune_qwen/github_flagscale/outputs/deepseek_v3/serve_logs +mkdir -p /mine/ip122/tune_qwen/github_flagscale/outputs/deepseek_v3/serve_logs/pids + +cd /mine/ip122/tune_qwen/github_flagscale + +cmd="CUDA_DEVICE_MAX_CONNECTIONS=1 python flagscale/serve/run_inference_engine.py --config-path=/mine/ip122/tune_qwen/github_flagscale/outputs/deepseek_v3/serve_logs/scripts/serve.yaml --log-dir=/mine/ip122/tune_qwen/github_flagscale/outputs/deepseek_v3/serve_logs" + +nohup bash -c "$cmd; sync" >> /mine/ip122/tune_qwen/github_flagscale/outputs/deepseek_v3/serve_logs/host_0_localhost.output 2>&1 & echo $! > /mine/ip122/tune_qwen/github_flagscale/outputs/deepseek_v3/serve_logs/pids/host_0_localhost.pid \ No newline at end of file diff --git a/flagscale/runner/runner_serve.py b/flagscale/runner/runner_serve.py index 2d2308a41..2638ffc71 100644 --- a/flagscale/runner/runner_serve.py +++ b/flagscale/runner/runner_serve.py @@ -1,5 +1,6 @@ import asyncio import contextlib +import copy import json import os import shlex @@ -14,6 +15,7 @@ from flagscale.runner.utils import ( benchmark, dummy_random_input, + flatten_dict_to_args, get_free_port, get_nproc_per_node, logger, @@ -22,6 +24,141 @@ ) +def _get_multiple_free_ports(num=1, exclude_ports=[]): + allocated_ports = [] + for i in range(num): + port = get_free_port() + while port in allocated_ports or port in exclude_ports: + port = get_free_port() + allocated_ports.append(port) + return allocated_ports + + +class ResourceManager: + def __init__(self, nodes): + """ + Initialize the ResourceManager with a list of nodes. + Each element in the list should be a two-item list: + - The first item is the node address (a string). + - The second item is a dictionary containing at least the key "slots". 
+ If "type" is not provided, it defaults to "gpu" with a warning. + The first node is treated as the master node, and the rest are worker nodes. + """ + self.nodes = self._initialize_nodes(nodes) + + def _initialize_nodes(self, nodes): + """ + Convert the input nodes list into the internal nodes representation. + Each node is converted into a dictionary with keys: + "address", "slots", "type", and "used" (initialized to 0). + If the "type" is not provided in a node, default it to "gpu" and issue a warning. + """ + initialized_nodes = [] + for node in nodes: + if len(node) != 2: + raise ValueError("Each node must include an address and node data") + address, info = node + if "slots" not in info: + raise ValueError("Node data must contain 'slots'") + if "type" not in info: + logger.warning( + f"Node {address} does not provide a resource type. Defaulting to 'gpu'." + ) + resource_type = info.get("type", "gpu") + initialized_nodes.append( + { + "address": address, + "slots": info["slots"], + "type": resource_type, + "used": 0, # Initialize used slot count to 0 + } + ) + return initialized_nodes + + def whole_card_num(self, resource_type="gpu"): + """ + Return the total number of slots across all nodes with the specified resource type. + The return type is int. + """ + total = 0 + for node in self.nodes: + if node["type"] == resource_type: + total += node["slots"] + return total + + def available_card_num(self, resource_type="gpu"): + """ + Return the total number of available slots (slots minus used) across all nodes with the specified resource type. + The return type is int. + """ + total = 0 + for node in self.nodes: + if node["type"] == resource_type: + total += node["slots"] - node["used"] + return total + + def available_card_ids(self, resource_type="gpu", address="auto", num=1): + """ + Allocate 'num' resource cards from a node and return a list of card indices. + + For the default case (address="auto"), traverse nodes in order: master node first, then worker nodes. + - If a node's available slots (slots - used) are >= num, allocate num consecutive indices (based on the current used value) + and update the node's used count, returning the allocated indices (0-indexed) as a list. + - If the available slots are insufficient at a particular node and address is "auto", continue searching through other nodes. + - If an explicit address is provided, check only that node; if it doesn't exist or lacks sufficient available slots, raise an error. + - If none of the nodes can satisfy the request, raise an error indicating insufficient resources. + """ + # Check the specified node if address is not "auto" + if address != "auto": + node_found = None + for node in self.nodes: + if node["address"] == address and node["type"] == resource_type: + node_found = node + break + if node_found is None: + raise ValueError(f"Node {address} does not exist or resource type mismatch") + free = node_found["slots"] - node_found["used"] + if free < num: + raise ValueError("Insufficient resources") + allocated_ids = list(range(node_found["used"], node_found["used"] + num)) + node_found["used"] += num + return allocated_ids + + # For address == "auto", traverse all nodes (master node first, then worker nodes) + for node in self.nodes: + if node["type"] == resource_type: + free = node["slots"] - node["used"] + if free >= num: + allocated_ids = list(range(node["used"], node["used"] + num)) + node["used"] += num + return allocated_ids + + # If no node satisfies the allocation request, raise an error. 
+ resource_status = self.get_status() + raise ValueError( + f"Require number {num} of resource_type {resource_type} But there is insufficient resources: \n{resource_status}" + ) + + def get_status(self): + """ + Return the status of all nodes as a dictionary. + Each key in the returned dictionary is the node's address, and its value is a dictionary with: + - type: the resource type. + - slots: the total number of slots. + - used: the number of allocated slots. + - available: the number of available slots (slots - used). + """ + status = {} + for node in self.nodes: + status[node["address"]] = { + "type": node["type"], + "slots": node["slots"], + "used": node["used"], + "available": node["slots"] - node["used"], + } + return status + + def _get_args_vllm(config: DictConfig): # see the following link for more details # https://github.com/facebookresearch/hydra/discussions/2750 @@ -150,6 +287,8 @@ def _generate_run_script_serve(config, host, node_rank, cmd, background=True, wi vllm_path = os.path.dirname(vllm.__path__[0]) except Exception as e: vllm_path = f"{root_dir}/vllm" + deploy_config = config.experiment.get("deploy", {}) + print(f"shell file ======================== {host_run_script_file}", flush=True) with open(host_run_script_file, "w") as f: f.write("#!/bin/bash\n\n") f.write("set -x\n") @@ -165,13 +304,157 @@ def _generate_run_script_serve(config, host, node_rank, cmd, background=True, wi f.write(f"\n") if nodes: - f.write(f"ray_path=$(realpath $(which ray))\n") - master_ip = nodes[0][0] - target_port = nodes[0][1].get("port") + if deploy_config.get("prefill_decode_disaggregation", False): + resource_manager = ResourceManager(nodes) + master_ip = nodes[0][0] + target_port = nodes[0][1].get("port") + p_num = deploy_config.get("prefill_num", 1) + d_num = deploy_config.get("decode_num", 1) + ports_num = (p_num + d_num) * 2 + 1 + kv_related_ports = _get_multiple_free_ports(ports_num) + kv_proxy_port = kv_related_ports.pop() + + engine_args = _get_engine_args(config) + command_items = ["vllm", "serve"] + command_items.append(engine_args["model"]) + other_args = flatten_dict_to_args(engine_args, ["model", "port"]) + command_items.extend(other_args) + vllm_command = " ".join(command_items) + if before_start_cmd: + vllm_command = f"{before_start_cmd} && " + node_cmd + p_address = deploy_config.get("prefill_address", "127.0.0.1") + d_address = deploy_config.get("decode_address", "127.0.0.1") + tensor_parallel_size = deploy_config.get("tensor_parallel_size", 1) + pipeline_parallel_size = deploy_config.get("pipeline_parallel_size", 1) + each_instance_card_num = tensor_parallel_size * pipeline_parallel_size + + f.write(f"# clean nodes \n") + if len(nodes) > 1: + for ip, node in nodes[1:]: + if not node.get("type", None): + raise ValueError( + f"Node type must be specified for node {node}. Available types are 'cpu', 'gpu', or a custom resource name." + ) + if not node.get("slots", None): + raise ValueError( + f"Number of slots must be specified for node {node}. This can be done by setting the 'slots' attribute." 
+ ) + node_cmd = f"pkill -f 'vllm serve'" + + if before_start_cmd: + node_cmd = f"{before_start_cmd} && " + node_cmd + + ssh_cmd = f'ssh -n -p {ssh_port} {ip} "{node_cmd}"' + + if docker_name: + ssh_cmd = f"ssh -n -p {ssh_port} {ip} \"docker exec {docker_name} /bin/bash -c '{node_cmd}'\"" + f.write(f"{ssh_cmd}\n") + if before_start_cmd: + f.write(f"{before_start_cmd} && pkill -f 'vllm serve'\n") + else: + f.write(f"pkill -f 'vllm serve'\n") + f.write("pkill -f 'run_inference_engine'\n") + f.write("pkill -f 'run_fs_serve_vllm'\n") + f.write("pkill -f 'vllm serve'\n") + f.write(f"\n") + + for _ in range(p_num): + kv_port = kv_related_ports.pop() + http_port = kv_related_ports.pop() + p_kv_config = { + "kv_connector": "P2pConnector", + "kv_role": "kv_producer", + "kv_port": kv_port, + "kv_connector_extra_config": { + "proxy_ip": master_ip, + "proxy_port": kv_proxy_port, + "http_port": http_port, + }, + } + card_ids = resource_manager.get_available_card_ids( + node_type=node["type"], slot_count=each_instance_card_num + ) + card_ids_str = ",".join(map(str, card_ids)) + ids_env = f"export CUDA_VISIBLE_DEVICES={card_ids_str}" + + p_kv_config_json = json.dumps(p_kv_config).replace('"', '\\"') + + node_cmd = f"{ids_env} && {vllm_command} --port {http_port} --kv-transfer-config '\\''{p_kv_config_json}'\\''" + if p_address != master_ip: + ssh_cmd = f'ssh -n -p {ssh_port} {p_address} "{node_cmd}"' + if docker_name: + ssh_cmd = f"ssh -n -p {ssh_port} {ip} \"docker exec {docker_name} /bin/bash -c '{node_cmd}'\"" + f.write(f"{ssh_cmd}\n") + else: + f.write(f"{node_cmd}\n") + + for _ in range(d_num): + kv_port = kv_related_ports.pop() + http_port = kv_related_ports.pop() + d_kv_config = { + "kv_connector": "P2pConnector", + "kv_role": "kv_consumer", + "kv_port": kv_port, + "kv_connector_extra_config": { + "proxy_ip": master_ip, + "proxy_port": kv_proxy_port, + "http_port": http_port, + }, + } + card_ids = resource_manager.get_available_card_ids( + node_type=node["type"], slot_count=each_instance_card_num + ) + card_ids_str = ",".join(map(str, card_ids)) + ids_env = f"export CUDA_VISIBLE_DEVICES={card_ids_str}" + + d_kv_config_json = json.dumps(d_kv_config).replace('"', '\\"') + node_cmd = f"{ids_env} && {vllm_command} --port {http_port} --kv-transfer-config '\\''{d_kv_config_json}'\\''" + if d_address != master_ip: + ssh_cmd = f'ssh -n -p {ssh_port} {d_address} "{node_cmd}"' + if docker_name: + ssh_cmd = f"ssh -n -p {ssh_port} {ip} \"docker exec {docker_name} /bin/bash -c '{node_cmd}'\"" + else: + f.write(f"{node_cmd}\n") + + else: + f.write(f"ray_path=$(realpath $(which ray))\n") + master_ip = nodes[0][0] + target_port = nodes[0][1].get("port") + + f.write(f"# clean nodes \n") + if len(nodes) > 1: + for ip, node in nodes[1:]: + if not node.get("type", None): + raise ValueError( + f"Node type must be specified for node {node}. Available types are 'cpu', 'gpu', or a custom resource name." + ) + if not node.get("slots", None): + raise ValueError( + f"Number of slots must be specified for node {node}. This can be done by setting the 'slots' attribute." 
+ ) + node_cmd = f"${{ray_path}} stop" + + if before_start_cmd: + node_cmd = f"{before_start_cmd} && " + node_cmd + + ssh_cmd = f'ssh -n -p {ssh_port} {ip} "{node_cmd}"' + + if docker_name: + ssh_cmd = f"ssh -n -p {ssh_port} {ip} \"docker exec {docker_name} /bin/bash -c '{node_cmd}'\"" + f.write(f"{ssh_cmd}\n") + if before_start_cmd: + f.write(f"{before_start_cmd} && ${{ray_path}} stop\n") + else: + f.write(f"${{ray_path}} stop\n") + f.write("pkill -f 'run_inference_engine'\n") + f.write("pkill -f 'run_fs_serve_vllm'\n") + f.write("pkill -f 'vllm serve'\n") + f.write(f"\n") + + master_port = target_port if target_port else get_free_port() - f.write(f"# clean nodes \n") - if len(nodes) > 1: - for ip, node in nodes[1:]: + address = f"{master_ip}:{master_port}" + for index, (ip, node) in enumerate(nodes): if not node.get("type", None): raise ValueError( f"Node type must be specified for node {node}. Available types are 'cpu', 'gpu', or a custom resource name." @@ -180,45 +463,21 @@ def _generate_run_script_serve(config, host, node_rank, cmd, background=True, wi raise ValueError( f"Number of slots must be specified for node {node}. This can be done by setting the 'slots' attribute." ) - node_cmd = f"${{ray_path}} stop" + if index == 0: + # master node + f.write(f"# start cluster\n") + f.write(f"# master node\n") + if node.type == "gpu": + node_cmd = f"${{ray_path}} start --head --port={master_port} --num-gpus={node.slots}" + elif node.type == "cpu": + node_cmd = f"${{ray_path}} start --head --port={master_port} --num-cpus={node.slots}" + else: + resource = json.dumps({node.type: node.slots}).replace('"', '\\"') + node_cmd = f"${{ray_path}} start --head --port={master_port} --resources='{resource}'" + if before_start_cmd: + node_cmd = f"{before_start_cmd} && " + node_cmd + f.write(f"{node_cmd}\n") - if before_start_cmd: - node_cmd = f"{before_start_cmd} && " + node_cmd - - ssh_cmd = f'ssh -n -p {ssh_port} {ip} "{node_cmd}"' - - if docker_name: - ssh_cmd = f"ssh -n -p {ssh_port} {ip} \"docker exec {docker_name} /bin/bash -c '{node_cmd}'\"" - f.write(f"{ssh_cmd}\n") - if before_start_cmd: - f.write(f"{before_start_cmd} && ${{ray_path}} stop\n") - else: - f.write(f"${{ray_path}} stop\n") - f.write("pkill -f 'run_inference_engine'\n") - f.write("pkill -f 'run_fs_serve_vllm'\n") - f.write("pkill -f 'vllm serve'\n") - f.write(f"\n") - - master_port = target_port if target_port else get_free_port() - - address = f"{master_ip}:{master_port}" - for index, (ip, node) in enumerate(nodes): - if not node.get("type", None): - raise ValueError( - f"Node type must be specified for node {node}. Available types are 'cpu', 'gpu', or a custom resource name." - ) - if not node.get("slots", None): - raise ValueError( - f"Number of slots must be specified for node {node}. This can be done by setting the 'slots' attribute." 
- ) - if index == 0: - # master node - f.write(f"# start cluster\n") - f.write(f"# master node\n") - if node.type == "gpu": - node_cmd = f"${{ray_path}} start --head --port={master_port} --num-gpus={node.slots}" - elif node.type == "cpu": - node_cmd = f"${{ray_path}} start --head --port={master_port} --num-cpus={node.slots}" else: resource = json.dumps({node.type: node.slots}).replace('"', '\\"') node_cmd = f"${{ray_path}} start --head --port={master_port} --resources='{resource}'" @@ -247,12 +506,11 @@ def _generate_run_script_serve(config, host, node_rank, cmd, background=True, wi ) if before_start_cmd: node_cmd = f"{before_start_cmd} && " + node_cmd + ssh_cmd = f'ssh -n -p {ssh_port} {ip} "{node_cmd}"' - ssh_cmd = f'ssh -n -p {ssh_port} {ip} "{node_cmd}"' - - if docker_name: - ssh_cmd = f"ssh -n -p {ssh_port} {ip} \"docker exec {docker_name} /bin/bash -c '{node_cmd}'\"" - f.write(f"{ssh_cmd}\n") + if docker_name: + ssh_cmd = f"ssh -n -p {ssh_port} {ip} \"docker exec {docker_name} /bin/bash -c '{node_cmd}'\"" + f.write(f"{ssh_cmd}\n") else: # Note: config key device_type is specified for single node serving in neither gpu or cpu. device_type = None @@ -267,7 +525,7 @@ def _generate_run_script_serve(config, host, node_rank, cmd, background=True, wi f"nproc_per_node must be specified when device_type {device_type} is specified." ) node_cmd = None - deploy_config = config.experiment.get("deploy", {}) + if deploy_config.get("use_fs_serve", True) and config.serve[0].get("engine", None): f.write(f"ray_path=$(realpath $(which ray))\n") if not device_type: From 87a0936b834564538c92250dfdefa299849008ff Mon Sep 17 00:00:00 2001 From: chenzhuo Date: Wed, 16 Apr 2025 18:37:42 +0800 Subject: [PATCH 03/62] fix code --- flagscale/runner/runner_serve.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/flagscale/runner/runner_serve.py b/flagscale/runner/runner_serve.py index 2638ffc71..0a2192e2f 100644 --- a/flagscale/runner/runner_serve.py +++ b/flagscale/runner/runner_serve.py @@ -97,7 +97,7 @@ def available_card_num(self, resource_type="gpu"): total += node["slots"] - node["used"] return total - def available_card_ids(self, resource_type="gpu", address="auto", num=1): + def get_available_card_ids(self, resource_type="gpu", address="auto", num=1): """ Allocate 'num' resource cards from a node and return a list of card indices. 
@@ -321,7 +321,7 @@ def _generate_run_script_serve(config, host, node_rank, cmd, background=True, wi
                 command_items.extend(other_args)
                 vllm_command = " ".join(command_items)
                 if before_start_cmd:
-                    vllm_command = f"{before_start_cmd} && " + node_cmd
+                    vllm_command = f"{before_start_cmd} && " + vllm_command
                 p_address = deploy_config.get("prefill_address", "127.0.0.1")
                 d_address = deploy_config.get("decode_address", "127.0.0.1")
                 tensor_parallel_size = deploy_config.get("tensor_parallel_size", 1)

From 384f92dd94683abfb1dad13e93d212f73adf3b2b Mon Sep 17 00:00:00 2001
From: chenzhuo
Date: Wed, 16 Apr 2025 18:39:44 +0800
Subject: [PATCH 04/62] fix code

---
 flagscale/runner/runner_serve.py | 7 ++++---
 1 file changed, 4 insertions(+), 3 deletions(-)

diff --git a/flagscale/runner/runner_serve.py b/flagscale/runner/runner_serve.py
index 0a2192e2f..5a1b3bf47 100644
--- a/flagscale/runner/runner_serve.py
+++ b/flagscale/runner/runner_serve.py
@@ -75,7 +75,7 @@ def _initialize_nodes(self, nodes):
             )
         return initialized_nodes
 
-    def whole_card_num(self, resource_type="gpu"):
+    def get_whole_card_num(self, resource_type="gpu"):
         """
         Return the total number of slots across all nodes with the specified resource type.
         The return type is int.
@@ -86,7 +86,7 @@ def whole_card_num(self, resource_type="gpu"):
                 total += node["slots"]
         return total
 
-    def available_card_num(self, resource_type="gpu"):
+    def get_available_card_num(self, resource_type="gpu"):
         """
         Return the total number of available slots (slots minus used) across all nodes with the specified resource type.
         The return type is int.
@@ -372,7 +372,8 @@ def _generate_run_script_serve(config, host, node_rank, cmd, background=True, wi
                         },
                     }
                     card_ids = resource_manager.get_available_card_ids(
-                        node_type=node["type"], slot_count=each_instance_card_num
+                        resource_type=node["type"],
+                        num=each_instance_card_num,
                     )
                     card_ids_str = ",".join(map(str, card_ids))
                     ids_env = f"export CUDA_VISIBLE_DEVICES={card_ids_str}"

From 8b647cd844f9dcb2699343ca07f23fb6da721431 Mon Sep 17 00:00:00 2001
From: chenzhuo
Date: Tue, 29 Apr 2025 00:38:14 +0800
Subject: [PATCH 05/62] fix conflict

---
 flagscale/runner/runner_serve.py | 5 ++---
 1 file changed, 2 insertions(+), 3 deletions(-)

diff --git a/flagscale/runner/runner_serve.py b/flagscale/runner/runner_serve.py
index 5a1b3bf47..0e8296814 100644
--- a/flagscale/runner/runner_serve.py
+++ b/flagscale/runner/runner_serve.py
@@ -372,8 +372,7 @@ def _generate_run_script_serve(config, host, node_rank, cmd, background=True, wi
                         },
                     }
                     card_ids = resource_manager.get_available_card_ids(
-                        resource_type=node["type"],
-                        num=each_instance_card_num,
+                        resource_type=node["type"], num=each_instance_card_num
                     )
                     card_ids_str = ",".join(map(str, card_ids))
                     ids_env = f"export CUDA_VISIBLE_DEVICES={card_ids_str}"

From 6ebfadd9557765b36bfb0c0a1ba1428b0ca5b389 Mon Sep 17 00:00:00 2001
From: chenzhuo
Date: Wed, 16 Apr 2025 18:55:51 +0800
Subject: [PATCH 06/62] fix code

---
 examples/qwen/conf/serve/serve_qwen2.5_7b.yaml |  2 +-
 flagscale/runner/runner_serve.py               | 17 +++++++----------
 2 files changed, 8 insertions(+), 11 deletions(-)

diff --git a/examples/qwen/conf/serve/serve_qwen2.5_7b.yaml 
b/examples/qwen/conf/serve/serve_qwen2.5_7b.yaml index 387d29b07..691f88827 100644 --- a/examples/qwen/conf/serve/serve_qwen2.5_7b.yaml +++ b/examples/qwen/conf/serve/serve_qwen2.5_7b.yaml @@ -2,7 +2,7 @@ engine: vllm engine_args: model: /models/Qwen2.5-7B-Instruct - tensor_parallel_size: 1 + tensor_parallel_size: 2 pipeline_parallel_size: 1 gpu_memory_utilization: 0.9 max_model_len: 32768 diff --git a/flagscale/runner/runner_serve.py b/flagscale/runner/runner_serve.py index 0e8296814..d781d8a1b 100644 --- a/flagscale/runner/runner_serve.py +++ b/flagscale/runner/runner_serve.py @@ -341,18 +341,11 @@ def _generate_run_script_serve(config, host, node_rank, cmd, background=True, wi ) node_cmd = f"pkill -f 'vllm serve'" - if before_start_cmd: - node_cmd = f"{before_start_cmd} && " + node_cmd - ssh_cmd = f'ssh -n -p {ssh_port} {ip} "{node_cmd}"' if docker_name: ssh_cmd = f"ssh -n -p {ssh_port} {ip} \"docker exec {docker_name} /bin/bash -c '{node_cmd}'\"" f.write(f"{ssh_cmd}\n") - if before_start_cmd: - f.write(f"{before_start_cmd} && pkill -f 'vllm serve'\n") - else: - f.write(f"pkill -f 'vllm serve'\n") f.write("pkill -f 'run_inference_engine'\n") f.write("pkill -f 'run_fs_serve_vllm'\n") f.write("pkill -f 'vllm serve'\n") @@ -381,9 +374,10 @@ def _generate_run_script_serve(config, host, node_rank, cmd, background=True, wi node_cmd = f"{ids_env} && {vllm_command} --port {http_port} --kv-transfer-config '\\''{p_kv_config_json}'\\''" if p_address != master_ip: - ssh_cmd = f'ssh -n -p {ssh_port} {p_address} "{node_cmd}"' if docker_name: ssh_cmd = f"ssh -n -p {ssh_port} {ip} \"docker exec {docker_name} /bin/bash -c '{node_cmd}'\"" + else: + ssh_cmd = f'ssh -n -p {ssh_port} {d_address} "{node_cmd}"' f.write(f"{ssh_cmd}\n") else: f.write(f"{node_cmd}\n") @@ -408,12 +402,15 @@ def _generate_run_script_serve(config, host, node_rank, cmd, background=True, wi ids_env = f"export CUDA_VISIBLE_DEVICES={card_ids_str}" d_kv_config_json = json.dumps(d_kv_config).replace('"', '\\"') - node_cmd = f"{ids_env} && {vllm_command} --port {http_port} --kv-transfer-config '\\''{d_kv_config_json}'\\''" if d_address != master_ip: - ssh_cmd = f'ssh -n -p {ssh_port} {d_address} "{node_cmd}"' + node_cmd = f"{ids_env} && {vllm_command} --port {http_port} --kv-transfer-config '\\''{d_kv_config_json}'\\''" if docker_name: ssh_cmd = f"ssh -n -p {ssh_port} {ip} \"docker exec {docker_name} /bin/bash -c '{node_cmd}'\"" + else: + ssh_cmd = f'ssh -n -p {ssh_port} {d_address} "{node_cmd}"' + f.write(f"{ssh_cmd}\n") else: + node_cmd = f"{ids_env} && {vllm_command} --port {http_port} --kv-transfer-config '{d_kv_config_json}'" f.write(f"{node_cmd}\n") else: From c3c55d917c165237153eee4d5de6867dd42f7d46 Mon Sep 17 00:00:00 2001 From: chenzhuo Date: Wed, 16 Apr 2025 19:02:32 +0800 Subject: [PATCH 07/62] fix code --- flagscale/runner/runner_serve.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/flagscale/runner/runner_serve.py b/flagscale/runner/runner_serve.py index d781d8a1b..245915d4d 100644 --- a/flagscale/runner/runner_serve.py +++ b/flagscale/runner/runner_serve.py @@ -372,14 +372,15 @@ def _generate_run_script_serve(config, host, node_rank, cmd, background=True, wi p_kv_config_json = json.dumps(p_kv_config).replace('"', '\\"') - node_cmd = f"{ids_env} && {vllm_command} --port {http_port} --kv-transfer-config '\\''{p_kv_config_json}'\\''" if p_address != master_ip: + node_cmd = f"{ids_env} && {vllm_command} --port {http_port} --kv-transfer-config '\\''{p_kv_config_json}'\\''" if docker_name: ssh_cmd = f"ssh -n 
-p {ssh_port} {ip} \"docker exec {docker_name} /bin/bash -c '{node_cmd}'\"" else: ssh_cmd = f'ssh -n -p {ssh_port} {d_address} "{node_cmd}"' f.write(f"{ssh_cmd}\n") else: + node_cmd = f"{ids_env} && {vllm_command} --port {http_port} --kv-transfer-config '{p_kv_config_json}'" f.write(f"{node_cmd}\n") for _ in range(d_num): From 45f6d7f4068d4e40191c823fe2770745488b82f9 Mon Sep 17 00:00:00 2001 From: chenzhuo Date: Wed, 16 Apr 2025 19:12:19 +0800 Subject: [PATCH 08/62] fix code --- flagscale/runner/runner_serve.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/flagscale/runner/runner_serve.py b/flagscale/runner/runner_serve.py index 245915d4d..b0574bac3 100644 --- a/flagscale/runner/runner_serve.py +++ b/flagscale/runner/runner_serve.py @@ -322,8 +322,8 @@ def _generate_run_script_serve(config, host, node_rank, cmd, background=True, wi vllm_command = " ".join(command_items) if before_start_cmd: vllm_command = f"{before_start_cmd} && " + vllm_command - p_address = deploy_config.get("prefill_address", "127.0.0.1") - d_address = deploy_config.get("decode_address", "127.0.0.1") + p_address = deploy_config.get("prefill_address", "auto") + d_address = deploy_config.get("decode_address", "auto") tensor_parallel_size = deploy_config.get("tensor_parallel_size", 1) pipeline_parallel_size = deploy_config.get("pipeline_parallel_size", 1) each_instance_card_num = tensor_parallel_size * pipeline_parallel_size @@ -365,7 +365,8 @@ def _generate_run_script_serve(config, host, node_rank, cmd, background=True, wi }, } card_ids = resource_manager.get_available_card_ids( - resource_type=node["type"], num=each_instance_card_num + address=p_address, + num=each_instance_card_num, ) card_ids_str = ",".join(map(str, card_ids)) ids_env = f"export CUDA_VISIBLE_DEVICES={card_ids_str}" @@ -397,7 +398,8 @@ def _generate_run_script_serve(config, host, node_rank, cmd, background=True, wi }, } card_ids = resource_manager.get_available_card_ids( - resource_type=node["type"], num=each_instance_card_num + address=d_address, + num=each_instance_card_num, ) card_ids_str = ",".join(map(str, card_ids)) ids_env = f"export CUDA_VISIBLE_DEVICES={card_ids_str}" From dee02b2ec788de8ce9240bac227232130193b5e6 Mon Sep 17 00:00:00 2001 From: chenzhuo Date: Wed, 16 Apr 2025 19:17:27 +0800 Subject: [PATCH 09/62] fix code --- flagscale/runner/runner_serve.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/flagscale/runner/runner_serve.py b/flagscale/runner/runner_serve.py index b0574bac3..6ba3437b6 100644 --- a/flagscale/runner/runner_serve.py +++ b/flagscale/runner/runner_serve.py @@ -324,8 +324,8 @@ def _generate_run_script_serve(config, host, node_rank, cmd, background=True, wi vllm_command = f"{before_start_cmd} && " + vllm_command p_address = deploy_config.get("prefill_address", "auto") d_address = deploy_config.get("decode_address", "auto") - tensor_parallel_size = deploy_config.get("tensor_parallel_size", 1) - pipeline_parallel_size = deploy_config.get("pipeline_parallel_size", 1) + tensor_parallel_size = engine_args.get("tensor_parallel_size", 1) + pipeline_parallel_size = engine_args.get("pipeline_parallel_size", 1) each_instance_card_num = tensor_parallel_size * pipeline_parallel_size f.write(f"# clean nodes \n") From ca48ee0b71413374992f87ad742403e648eaff18 Mon Sep 17 00:00:00 2001 From: chenzhuo Date: Wed, 16 Apr 2025 19:37:36 +0800 Subject: [PATCH 10/62] fix code --- flagscale/runner/runner_serve.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git 
a/flagscale/runner/runner_serve.py b/flagscale/runner/runner_serve.py index 6ba3437b6..86a565783 100644 --- a/flagscale/runner/runner_serve.py +++ b/flagscale/runner/runner_serve.py @@ -371,10 +371,11 @@ def _generate_run_script_serve(config, host, node_rank, cmd, background=True, wi card_ids_str = ",".join(map(str, card_ids)) ids_env = f"export CUDA_VISIBLE_DEVICES={card_ids_str}" - p_kv_config_json = json.dumps(p_kv_config).replace('"', '\\"') + p_kv_config_json = json.dumps(p_kv_config) if p_address != master_ip: - node_cmd = f"{ids_env} && {vllm_command} --port {http_port} --kv-transfer-config '\\''{p_kv_config_json}'\\''" + p_kv_config_formate_json = p_kv_config_json.replace('"', '\\"') + node_cmd = f"{ids_env} && {vllm_command} --port {http_port} --kv-transfer-config '\\''{p_kv_config_formate_json}'\\''" if docker_name: ssh_cmd = f"ssh -n -p {ssh_port} {ip} \"docker exec {docker_name} /bin/bash -c '{node_cmd}'\"" else: @@ -404,9 +405,10 @@ def _generate_run_script_serve(config, host, node_rank, cmd, background=True, wi card_ids_str = ",".join(map(str, card_ids)) ids_env = f"export CUDA_VISIBLE_DEVICES={card_ids_str}" - d_kv_config_json = json.dumps(d_kv_config).replace('"', '\\"') + d_kv_config_json = json.dumps(d_kv_config) if d_address != master_ip: - node_cmd = f"{ids_env} && {vllm_command} --port {http_port} --kv-transfer-config '\\''{d_kv_config_json}'\\''" + d_kv_config_formate_json = d_kv_config_json.replace('"', '\\"') + node_cmd = f"{ids_env} && {vllm_command} --port {http_port} --kv-transfer-config '\\''{d_kv_config_formate_json}'\\''" if docker_name: ssh_cmd = f"ssh -n -p {ssh_port} {ip} \"docker exec {docker_name} /bin/bash -c '{node_cmd}'\"" else: From 314d41cb2b38a6535ac86d5bc6a9530ff7c257cf Mon Sep 17 00:00:00 2001 From: chenzhuo Date: Wed, 16 Apr 2025 23:55:54 +0800 Subject: [PATCH 11/62] fix code --- flagscale/runner/runner_serve.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/flagscale/runner/runner_serve.py b/flagscale/runner/runner_serve.py index 86a565783..67a97dd59 100644 --- a/flagscale/runner/runner_serve.py +++ b/flagscale/runner/runner_serve.py @@ -357,11 +357,11 @@ def _generate_run_script_serve(config, host, node_rank, cmd, background=True, wi p_kv_config = { "kv_connector": "P2pConnector", "kv_role": "kv_producer", - "kv_port": kv_port, + "kv_port": str(kv_port), "kv_connector_extra_config": { "proxy_ip": master_ip, - "proxy_port": kv_proxy_port, - "http_port": http_port, + "proxy_port": str(kv_proxy_port), + "http_port": str(http_port), }, } card_ids = resource_manager.get_available_card_ids( @@ -383,7 +383,7 @@ def _generate_run_script_serve(config, host, node_rank, cmd, background=True, wi f.write(f"{ssh_cmd}\n") else: node_cmd = f"{ids_env} && {vllm_command} --port {http_port} --kv-transfer-config '{p_kv_config_json}'" - f.write(f"{node_cmd}\n") + f.write(f"{node_cmd} &\n") for _ in range(d_num): kv_port = kv_related_ports.pop() @@ -391,11 +391,11 @@ def _generate_run_script_serve(config, host, node_rank, cmd, background=True, wi d_kv_config = { "kv_connector": "P2pConnector", "kv_role": "kv_consumer", - "kv_port": kv_port, + "kv_port": str(kv_port), "kv_connector_extra_config": { "proxy_ip": master_ip, - "proxy_port": kv_proxy_port, - "http_port": http_port, + "proxy_port": str(kv_proxy_port), + "http_port": str(http_port), }, } card_ids = resource_manager.get_available_card_ids( @@ -416,7 +416,7 @@ def _generate_run_script_serve(config, host, node_rank, cmd, background=True, wi f.write(f"{ssh_cmd}\n") else: 
node_cmd = f"{ids_env} && {vllm_command} --port {http_port} --kv-transfer-config '{d_kv_config_json}'" - f.write(f"{node_cmd}\n") + f.write(f"{node_cmd} &\n") else: f.write(f"ray_path=$(realpath $(which ray))\n") From 3ea878a5f3499dc61e418ec2d5c28bd7631dabf0 Mon Sep 17 00:00:00 2001 From: chenzhuo Date: Thu, 17 Apr 2025 00:38:32 +0800 Subject: [PATCH 12/62] fix code --- flagscale/runner/runner_serve.py | 27 +++++++++++++++++++-------- 1 file changed, 19 insertions(+), 8 deletions(-) diff --git a/flagscale/runner/runner_serve.py b/flagscale/runner/runner_serve.py index 67a97dd59..0d1b151b8 100644 --- a/flagscale/runner/runner_serve.py +++ b/flagscale/runner/runner_serve.py @@ -320,6 +320,7 @@ def _generate_run_script_serve(config, host, node_rank, cmd, background=True, wi other_args = flatten_dict_to_args(engine_args, ["model", "port"]) command_items.extend(other_args) vllm_command = " ".join(command_items) + vllm_command = "nohup " + vllm_command if before_start_cmd: vllm_command = f"{before_start_cmd} && " + vllm_command p_address = deploy_config.get("prefill_address", "auto") @@ -327,6 +328,7 @@ def _generate_run_script_serve(config, host, node_rank, cmd, background=True, wi tensor_parallel_size = engine_args.get("tensor_parallel_size", 1) pipeline_parallel_size = engine_args.get("pipeline_parallel_size", 1) each_instance_card_num = tensor_parallel_size * pipeline_parallel_size + default_log_dir = "/tmp/flagscale" f.write(f"# clean nodes \n") if len(nodes) > 1: @@ -339,19 +341,21 @@ def _generate_run_script_serve(config, host, node_rank, cmd, background=True, wi raise ValueError( f"Number of slots must be specified for node {node}. This can be done by setting the 'slots' attribute." ) - node_cmd = f"pkill -f 'vllm serve'" + node_cmd = f"pkill -f vllm && mkdir -p {default_log_dir}" ssh_cmd = f'ssh -n -p {ssh_port} {ip} "{node_cmd}"' if docker_name: ssh_cmd = f"ssh -n -p {ssh_port} {ip} \"docker exec {docker_name} /bin/bash -c '{node_cmd}'\"" f.write(f"{ssh_cmd}\n") + f.write("pkill -f 'run_inference_engine'\n") f.write("pkill -f 'run_fs_serve_vllm'\n") f.write("pkill -f 'vllm serve'\n") + f.write(f"mkdir -p {default_log_dir}\n") f.write(f"\n") - for _ in range(p_num): + for i in range(p_num): kv_port = kv_related_ports.pop() http_port = kv_related_ports.pop() p_kv_config = { @@ -372,17 +376,20 @@ def _generate_run_script_serve(config, host, node_rank, cmd, background=True, wi ids_env = f"export CUDA_VISIBLE_DEVICES={card_ids_str}" p_kv_config_json = json.dumps(p_kv_config) + p_instance_log_path = os.path.join( + default_log_dir, f"prefill_{i}.log" + ) if p_address != master_ip: p_kv_config_formate_json = p_kv_config_json.replace('"', '\\"') node_cmd = f"{ids_env} && {vllm_command} --port {http_port} --kv-transfer-config '\\''{p_kv_config_formate_json}'\\''" if docker_name: - ssh_cmd = f"ssh -n -p {ssh_port} {ip} \"docker exec {docker_name} /bin/bash -c '{node_cmd}'\"" + ssh_cmd = f"ssh -f -n -p {ssh_port} {ip} \"docker exec {docker_name} /bin/bash -c '{node_cmd} > {p_instance_log_path} 2>&1 &'\"" else: - ssh_cmd = f'ssh -n -p {ssh_port} {d_address} "{node_cmd}"' + ssh_cmd = f'ssh -f -n -p {ssh_port} {d_address} "{node_cmd} > {p_instance_log_path} 2>&1 &"' f.write(f"{ssh_cmd}\n") else: - node_cmd = f"{ids_env} && {vllm_command} --port {http_port} --kv-transfer-config '{p_kv_config_json}'" + node_cmd = f"{ids_env} && {vllm_command} --port {http_port} --kv-transfer-config '{p_kv_config_json}' > {p_instance_log_path} 2>&1 &" f.write(f"{node_cmd} &\n") for _ in range(d_num): @@ -406,16 +413,20 
@@ def _generate_run_script_serve(config, host, node_rank, cmd, background=True, wi ids_env = f"export CUDA_VISIBLE_DEVICES={card_ids_str}" d_kv_config_json = json.dumps(d_kv_config) + d_instance_log_path = os.path.join( + default_log_dir, f"decode_{i}.log" + ) + if d_address != master_ip: d_kv_config_formate_json = d_kv_config_json.replace('"', '\\"') node_cmd = f"{ids_env} && {vllm_command} --port {http_port} --kv-transfer-config '\\''{d_kv_config_formate_json}'\\''" if docker_name: - ssh_cmd = f"ssh -n -p {ssh_port} {ip} \"docker exec {docker_name} /bin/bash -c '{node_cmd}'\"" + ssh_cmd = f"ssh -n -p {ssh_port} {ip} \"docker exec {docker_name} /bin/bash -c '{node_cmd} > {d_instance_log_path} 2>&1 &'\"" else: - ssh_cmd = f'ssh -n -p {ssh_port} {d_address} "{node_cmd}"' + ssh_cmd = f'ssh -n -p {ssh_port} {d_address} "{node_cmd} > {d_instance_log_path} 2>&1 &"' f.write(f"{ssh_cmd}\n") else: - node_cmd = f"{ids_env} && {vllm_command} --port {http_port} --kv-transfer-config '{d_kv_config_json}'" + node_cmd = f"{ids_env} && {vllm_command} --port {http_port} --kv-transfer-config '{d_kv_config_json}' > {d_instance_log_path} 2>&1 &" f.write(f"{node_cmd} &\n") else: From 3e3b7e05588020c4346d20d80d5c63ad8d3167a8 Mon Sep 17 00:00:00 2001 From: chenzhuo Date: Thu, 17 Apr 2025 00:42:05 +0800 Subject: [PATCH 13/62] fix code --- flagscale/runner/runner_serve.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/flagscale/runner/runner_serve.py b/flagscale/runner/runner_serve.py index 0d1b151b8..a87690d1c 100644 --- a/flagscale/runner/runner_serve.py +++ b/flagscale/runner/runner_serve.py @@ -387,10 +387,10 @@ def _generate_run_script_serve(config, host, node_rank, cmd, background=True, wi ssh_cmd = f"ssh -f -n -p {ssh_port} {ip} \"docker exec {docker_name} /bin/bash -c '{node_cmd} > {p_instance_log_path} 2>&1 &'\"" else: ssh_cmd = f'ssh -f -n -p {ssh_port} {d_address} "{node_cmd} > {p_instance_log_path} 2>&1 &"' - f.write(f"{ssh_cmd}\n") + f.write(f"{ssh_cmd}\n\n") else: node_cmd = f"{ids_env} && {vllm_command} --port {http_port} --kv-transfer-config '{p_kv_config_json}' > {p_instance_log_path} 2>&1 &" - f.write(f"{node_cmd} &\n") + f.write(f"{node_cmd}\n\n") for _ in range(d_num): kv_port = kv_related_ports.pop() @@ -424,10 +424,10 @@ def _generate_run_script_serve(config, host, node_rank, cmd, background=True, wi ssh_cmd = f"ssh -n -p {ssh_port} {ip} \"docker exec {docker_name} /bin/bash -c '{node_cmd} > {d_instance_log_path} 2>&1 &'\"" else: ssh_cmd = f'ssh -n -p {ssh_port} {d_address} "{node_cmd} > {d_instance_log_path} 2>&1 &"' - f.write(f"{ssh_cmd}\n") + f.write(f"{ssh_cmd}\n\n") else: node_cmd = f"{ids_env} && {vllm_command} --port {http_port} --kv-transfer-config '{d_kv_config_json}' > {d_instance_log_path} 2>&1 &" - f.write(f"{node_cmd} &\n") + f.write(f"{node_cmd}\n\n") else: f.write(f"ray_path=$(realpath $(which ray))\n") From 29dcdd7714634db2c110751786eaad18c4806fef Mon Sep 17 00:00:00 2001 From: chenzhuo Date: Thu, 17 Apr 2025 00:53:09 +0800 Subject: [PATCH 14/62] fix code --- flagscale/runner/runner_serve.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/flagscale/runner/runner_serve.py b/flagscale/runner/runner_serve.py index a87690d1c..4edb54fa3 100644 --- a/flagscale/runner/runner_serve.py +++ b/flagscale/runner/runner_serve.py @@ -341,7 +341,7 @@ def _generate_run_script_serve(config, host, node_rank, cmd, background=True, wi raise ValueError( f"Number of slots must be specified for node {node}. 
This can be done by setting the 'slots' attribute." ) - node_cmd = f"pkill -f vllm && mkdir -p {default_log_dir}" + node_cmd = f"mkdir -p {default_log_dir} && pkill -f vllm" ssh_cmd = f'ssh -n -p {ssh_port} {ip} "{node_cmd}"' @@ -392,7 +392,7 @@ def _generate_run_script_serve(config, host, node_rank, cmd, background=True, wi node_cmd = f"{ids_env} && {vllm_command} --port {http_port} --kv-transfer-config '{p_kv_config_json}' > {p_instance_log_path} 2>&1 &" f.write(f"{node_cmd}\n\n") - for _ in range(d_num): + for j in range(d_num): kv_port = kv_related_ports.pop() http_port = kv_related_ports.pop() d_kv_config = { @@ -414,7 +414,7 @@ def _generate_run_script_serve(config, host, node_rank, cmd, background=True, wi d_kv_config_json = json.dumps(d_kv_config) d_instance_log_path = os.path.join( - default_log_dir, f"decode_{i}.log" + default_log_dir, f"decode_{j}.log" ) if d_address != master_ip: From 85e6bb15b044ab534b5b303100bb9cbd37f4ce71 Mon Sep 17 00:00:00 2001 From: chenzhuo Date: Thu, 17 Apr 2025 00:58:29 +0800 Subject: [PATCH 15/62] fix code --- flagscale/runner/runner_serve.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/flagscale/runner/runner_serve.py b/flagscale/runner/runner_serve.py index 4edb54fa3..e9170fe52 100644 --- a/flagscale/runner/runner_serve.py +++ b/flagscale/runner/runner_serve.py @@ -313,6 +313,7 @@ def _generate_run_script_serve(config, host, node_rank, cmd, background=True, wi ports_num = (p_num + d_num) * 2 + 1 kv_related_ports = _get_multiple_free_ports(ports_num) kv_proxy_port = kv_related_ports.pop() + kv_proxy_port = 30001 # debug, tobe removed engine_args = _get_engine_args(config) command_items = ["vllm", "serve"] @@ -645,8 +646,12 @@ def _prepare(self): self.user_args = _get_args_vllm(self.config) self.user_envs = self.config.experiment.get("envs", {}) entrypoint = self.config.experiment.task.get("entrypoint", None) - if self.inference_engine: - if not self.use_fs_serve: + if self.inference_engine: # pd_disagg_router + if self.config.experiment.get("deploy", {}).get( + "prefill_decode_disaggregation", False + ): + self.user_script = "flagscale/serve/run_pd_disagg_router.py" + elif not self.use_fs_serve: self.user_script = "flagscale/serve/run_inference_engine.py" else: self.user_script = "flagscale/serve/run_fs_serve_vllm.py" From cf25c807dcac9ad681f77dad3eb09b22933b23f7 Mon Sep 17 00:00:00 2001 From: chenzhuo Date: Thu, 17 Apr 2025 01:08:29 +0800 Subject: [PATCH 16/62] fix code --- flagscale/serve/{pd_disagg_router.py => run_pd_disagg_router.py} | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename flagscale/serve/{pd_disagg_router.py => run_pd_disagg_router.py} (100%) diff --git a/flagscale/serve/pd_disagg_router.py b/flagscale/serve/run_pd_disagg_router.py similarity index 100% rename from flagscale/serve/pd_disagg_router.py rename to flagscale/serve/run_pd_disagg_router.py From aab9e5a047cac3bd4e373c5d2f0763f907826331 Mon Sep 17 00:00:00 2001 From: chenzhuo Date: Thu, 17 Apr 2025 01:12:21 +0800 Subject: [PATCH 17/62] fix code --- flagscale/runner/runner_serve.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/flagscale/runner/runner_serve.py b/flagscale/runner/runner_serve.py index e9170fe52..b85751c65 100644 --- a/flagscale/runner/runner_serve.py +++ b/flagscale/runner/runner_serve.py @@ -422,9 +422,9 @@ def _generate_run_script_serve(config, host, node_rank, cmd, background=True, wi d_kv_config_formate_json = d_kv_config_json.replace('"', '\\"') node_cmd = f"{ids_env} && {vllm_command} --port 
{http_port} --kv-transfer-config '\\''{d_kv_config_formate_json}'\\''" if docker_name: - ssh_cmd = f"ssh -n -p {ssh_port} {ip} \"docker exec {docker_name} /bin/bash -c '{node_cmd} > {d_instance_log_path} 2>&1 &'\"" + ssh_cmd = f"ssh -f -n -p {ssh_port} {ip} \"docker exec {docker_name} /bin/bash -c '{node_cmd} > {d_instance_log_path} 2>&1 &'\"" else: - ssh_cmd = f'ssh -n -p {ssh_port} {d_address} "{node_cmd} > {d_instance_log_path} 2>&1 &"' + ssh_cmd = f'ssh -f -n -p {ssh_port} {d_address} "{node_cmd} > {d_instance_log_path} 2>&1 &"' f.write(f"{ssh_cmd}\n\n") else: node_cmd = f"{ids_env} && {vllm_command} --port {http_port} --kv-transfer-config '{d_kv_config_json}' > {d_instance_log_path} 2>&1 &" From ec8f6fd7bc3738d772903fe425505c67513cdf9e Mon Sep 17 00:00:00 2001 From: chenzhuo Date: Thu, 17 Apr 2025 01:18:06 +0800 Subject: [PATCH 18/62] fix code --- examples/qwen/conf/serve/serve_qwen2.5_7b.yaml | 1 + 1 file changed, 1 insertion(+) diff --git a/examples/qwen/conf/serve/serve_qwen2.5_7b.yaml b/examples/qwen/conf/serve/serve_qwen2.5_7b.yaml index 691f88827..31f2bf1c0 100644 --- a/examples/qwen/conf/serve/serve_qwen2.5_7b.yaml +++ b/examples/qwen/conf/serve/serve_qwen2.5_7b.yaml @@ -2,6 +2,7 @@ engine: vllm engine_args: model: /models/Qwen2.5-7B-Instruct + host: 0.0.0.0 tensor_parallel_size: 2 pipeline_parallel_size: 1 gpu_memory_utilization: 0.9 From 1e1bcb88363d461e8725704c9b54dae3c8cf3f9c Mon Sep 17 00:00:00 2001 From: chenzhuo Date: Thu, 17 Apr 2025 01:33:23 +0800 Subject: [PATCH 19/62] fix code --- flagscale/serve/run_pd_disagg_router.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/flagscale/serve/run_pd_disagg_router.py b/flagscale/serve/run_pd_disagg_router.py index f909edc3c..72429d411 100644 --- a/flagscale/serve/run_pd_disagg_router.py +++ b/flagscale/serve/run_pd_disagg_router.py @@ -149,7 +149,7 @@ async def handle_request(): @app.route("/v1/chat/completions", methods=["POST"]) -async def handle_request(): +async def handle_chat_request(): try: original_request_data = await request.get_json() @@ -164,7 +164,7 @@ async def handle_request(): list(prefill_instances.items()) ) print( - "handle_request, prefill_addr: %s, zmq_addr: %s", + "handle_chat_request, prefill_addr: %s, zmq_addr: %s", prefill_addr, prefill_zmq_addr, ) @@ -174,7 +174,7 @@ async def handle_request(): with decode_cv: decode_addr, decode_zmq_addr = random.choice(list(decode_instances.items())) print( - "handle_request, decode_addr: %s, zmq_addr: %s", + "handle_chat_request, decode_addr: %s, zmq_addr: %s", decode_addr, decode_zmq_addr, ) From 987a37c2fa93cdc3f6241d00c63545281a9eaf17 Mon Sep 17 00:00:00 2001 From: chenzhuo Date: Thu, 17 Apr 2025 10:16:00 +0800 Subject: [PATCH 20/62] fix code --- flagscale/runner/runner_serve.py | 1 + 1 file changed, 1 insertion(+) diff --git a/flagscale/runner/runner_serve.py b/flagscale/runner/runner_serve.py index b85751c65..e05ca8374 100644 --- a/flagscale/runner/runner_serve.py +++ b/flagscale/runner/runner_serve.py @@ -353,6 +353,7 @@ def _generate_run_script_serve(config, host, node_rank, cmd, background=True, wi f.write("pkill -f 'run_inference_engine'\n") f.write("pkill -f 'run_fs_serve_vllm'\n") f.write("pkill -f 'vllm serve'\n") + f.write("pkill -f 'run_pd_disagg_router'\n") f.write(f"mkdir -p {default_log_dir}\n") f.write(f"\n") From b617ac23976531abc8bf75455c6d7c19b964b26a Mon Sep 17 00:00:00 2001 From: chenzhuo Date: Thu, 17 Apr 2025 11:16:53 +0800 Subject: [PATCH 21/62] fix code --- flagscale/runner/runner_serve.py | 9 ++++++++- 
flagscale/serve/run_pd_disagg_router.py | 17 ++++++++++++++++- 2 files changed, 24 insertions(+), 2 deletions(-) diff --git a/flagscale/runner/runner_serve.py b/flagscale/runner/runner_serve.py index e05ca8374..78ef365e7 100644 --- a/flagscale/runner/runner_serve.py +++ b/flagscale/runner/runner_serve.py @@ -329,7 +329,9 @@ def _generate_run_script_serve(config, host, node_rank, cmd, background=True, wi tensor_parallel_size = engine_args.get("tensor_parallel_size", 1) pipeline_parallel_size = engine_args.get("pipeline_parallel_size", 1) each_instance_card_num = tensor_parallel_size * pipeline_parallel_size - default_log_dir = "/tmp/flagscale" + default_log_dir = deploy_config.get( + "prefill_decode_log_dir", logging_config.log_dir + ) f.write(f"# clean nodes \n") if len(nodes) > 1: @@ -357,6 +359,8 @@ def _generate_run_script_serve(config, host, node_rank, cmd, background=True, wi f.write(f"mkdir -p {default_log_dir}\n") f.write(f"\n") + f.write("echo '=========== launch prefill instance ==========='\n") + for i in range(p_num): kv_port = kv_related_ports.pop() http_port = kv_related_ports.pop() @@ -394,6 +398,8 @@ def _generate_run_script_serve(config, host, node_rank, cmd, background=True, wi node_cmd = f"{ids_env} && {vllm_command} --port {http_port} --kv-transfer-config '{p_kv_config_json}' > {p_instance_log_path} 2>&1 &" f.write(f"{node_cmd}\n\n") + f.write("echo '=========== launch decode instance ==========='\n") + for j in range(d_num): kv_port = kv_related_ports.pop() http_port = kv_related_ports.pop() @@ -566,6 +572,7 @@ def _generate_run_script_serve(config, host, node_rank, cmd, background=True, wi f.write(f"\n") # TODO: need a option to control whether to append or overwrite the output file # Now, it always appends to the output file + f.write("echo '=========== launch task ==========='\n") if background: f.write( f'nohup bash -c "$cmd; sync" >> {host_output_file} 2>&1 & echo $! 
> {host_pid_file}\n' diff --git a/flagscale/serve/run_pd_disagg_router.py b/flagscale/serve/run_pd_disagg_router.py index 72429d411..54ff826db 100644 --- a/flagscale/serve/run_pd_disagg_router.py +++ b/flagscale/serve/run_pd_disagg_router.py @@ -120,6 +120,14 @@ async def handle_request(): decode_addr, decode_zmq_addr, ) + print( + f"======== {prefill_zmq_addr} /v1/completions prefill_instances {prefill_instances} ========== ", + flush=True, + ) + print( + f"======== {decode_zmq_addr} /v1/completions decode_instances {decode_instances} ========== ", + flush=True, + ) request_id = f"___prefill_addr_{prefill_zmq_addr}___decode_addr_{decode_zmq_addr}_{random_uuid()}" @@ -178,7 +186,14 @@ async def handle_chat_request(): decode_addr, decode_zmq_addr, ) - + print( + f"======== {prefill_zmq_addr} /v1/chat/completions prefill_instances {prefill_instances} ========== ", + flush=True, + ) + print( + f"======== {decode_zmq_addr} /v1/chat/completions decode_instances {decode_instances} ========== ", + flush=True, + ) request_id = f"___prefill_addr_{prefill_zmq_addr}___decode_addr_{decode_zmq_addr}_{random_uuid()}" # finish prefill From ea0832b51bad505d0e7a70ff2ec3475c4c95d834 Mon Sep 17 00:00:00 2001 From: chenzhuo Date: Thu, 17 Apr 2025 11:56:29 +0800 Subject: [PATCH 22/62] fix code --- examples/qwen/conf/serve/serve_qwen2.5_7b.yaml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/examples/qwen/conf/serve/serve_qwen2.5_7b.yaml b/examples/qwen/conf/serve/serve_qwen2.5_7b.yaml index 31f2bf1c0..71c72393f 100644 --- a/examples/qwen/conf/serve/serve_qwen2.5_7b.yaml +++ b/examples/qwen/conf/serve/serve_qwen2.5_7b.yaml @@ -1,11 +1,11 @@ - serve_id: vllm_model engine: vllm engine_args: - model: /models/Qwen2.5-7B-Instruct + model: /models/Qwen2.5-0.5B-Instruct host: 0.0.0.0 tensor_parallel_size: 2 pipeline_parallel_size: 1 - gpu_memory_utilization: 0.9 + gpu_memory_utilization: 0.1 max_model_len: 32768 max_num_seqs: 256 enforce_eager: true From eea51d28d3ee11b56795d0131f6f0a5c1877f492 Mon Sep 17 00:00:00 2001 From: chenzhuo Date: Thu, 17 Apr 2025 14:27:24 +0800 Subject: [PATCH 23/62] fix code --- .../qwen/conf/config_qwen2.5_7b_pd_disaggregation.yaml | 2 +- examples/qwen/conf/serve/serve_qwen2.5_7b.yaml | 2 +- flagscale/runner/runner_serve.py | 8 ++++++++ 3 files changed, 10 insertions(+), 2 deletions(-) diff --git a/examples/qwen/conf/config_qwen2.5_7b_pd_disaggregation.yaml b/examples/qwen/conf/config_qwen2.5_7b_pd_disaggregation.yaml index 25b8144b4..51d774ae7 100644 --- a/examples/qwen/conf/config_qwen2.5_7b_pd_disaggregation.yaml +++ b/examples/qwen/conf/config_qwen2.5_7b_pd_disaggregation.yaml @@ -13,7 +13,7 @@ experiment: prefill_decode_disaggregation: true prefill_num: 1 prefill_address: 10.1.1.122 - decode_num: 2 + decode_num: 1 decode_address: 10.1.1.108 runner: hostfile: examples/deepseek/conf/hostfile.txt diff --git a/examples/qwen/conf/serve/serve_qwen2.5_7b.yaml b/examples/qwen/conf/serve/serve_qwen2.5_7b.yaml index 71c72393f..3a7a20130 100644 --- a/examples/qwen/conf/serve/serve_qwen2.5_7b.yaml +++ b/examples/qwen/conf/serve/serve_qwen2.5_7b.yaml @@ -3,7 +3,7 @@ engine_args: model: /models/Qwen2.5-0.5B-Instruct host: 0.0.0.0 - tensor_parallel_size: 2 + tensor_parallel_size: 1 pipeline_parallel_size: 1 gpu_memory_utilization: 0.1 max_model_len: 32768 diff --git a/flagscale/runner/runner_serve.py b/flagscale/runner/runner_serve.py index 78ef365e7..53247e99f 100644 --- a/flagscale/runner/runner_serve.py +++ b/flagscale/runner/runner_serve.py @@ -374,6 +374,10 @@ def 
_generate_run_script_serve(config, host, node_rank, cmd, background=True, wi
                     "http_port": str(http_port),
                 },
             }
+            print(
+                f"============= prefill instance {i}, p_kv_config: {p_kv_config} =============",
+                flush=True,
+            )
             card_ids = resource_manager.get_available_card_ids(
                 address=p_address,
                 num=each_instance_card_num,
@@ -413,6 +417,10 @@ def _generate_run_script_serve(config, host, node_rank, cmd, background=True, wi
                     "http_port": str(http_port),
                 },
             }
+            print(
+                f"============= decode instance {j}, d_kv_config: {d_kv_config} =============",
+                flush=True,
+            )
             card_ids = resource_manager.get_available_card_ids(
                 address=d_address,
                 num=each_instance_card_num,

From 8540109d0e0a7c927e661671fee58f0cd9571326 Mon Sep 17 00:00:00 2001
From: chenzhuo
Date: Thu, 17 Apr 2025 15:43:53 +0800
Subject: [PATCH 24/62] fix code

---
 flagscale/serve/run_pd_disagg_router.py | 303 ++++++++++--------------
 1 file changed, 122 insertions(+), 181 deletions(-)

diff --git a/flagscale/serve/run_pd_disagg_router.py b/flagscale/serve/run_pd_disagg_router.py
index 54ff826db..c27a44067 100644
--- a/flagscale/serve/run_pd_disagg_router.py
+++ b/flagscale/serve/run_pd_disagg_router.py
@@ -1,229 +1,170 @@
+from __future__ import annotations
+
+import asyncio
 import os
 import random
 import socket
 import threading
 import uuid
+from enum import Enum
+from typing import Dict, Literal, Optional, Tuple

 import aiohttp
 import msgpack
 import zmq
 from quart import Quart, make_response, request

-prefill_instances: dict[str, str] = {}  # http_address: zmq_address
-decode_instances: dict[str, str] = {}  # http_address: zmq_address

-prefill_cv = threading.Condition()
-decode_cv = threading.Condition()

+# ─── Registry ──────────────────────────────────────────────────────────────────
+class InstanceType(Enum):
+    PREFILL = "P"
+    DECODE = "D"
+
+
+class ServiceRegistry:
+    """Thread‑safe container for prefill / decode instance metadata."""
+
+    def __init__(self) -> None:
+        self._lock = threading.RLock()
+        self._cond_prefill = threading.Condition(self._lock)
+        self._cond_decode = threading.Condition(self._lock)
+        self._instances: Dict[InstanceType, Dict[str, str]] = {
+            InstanceType.PREFILL: {},
+            InstanceType.DECODE: {},
+        }
+
+    # ---- public API ----------------------------------------------------------
+    def register(
+        self,
+        itype: InstanceType | Literal["P", "D"],
+        http_addr: str,
+        zmq_addr: str,
+    ) -> None:
+        itype = InstanceType(itype)  # cast if literal
+        with self._lock:
+            self._instances[itype][http_addr] = zmq_addr
+            # wake one waiter if someone is blocked on this pool
+            cond = (
+                self._cond_prefill
+                if itype is InstanceType.PREFILL
+                else self._cond_decode
+            )
+            cond.notify()
+
+    def random_instance(self, itype: InstanceType) -> Tuple[str, str]:
+        """Blocks until at least one instance of *itype* is present."""
+        cond = (
+            self._cond_prefill if itype is InstanceType.PREFILL else self._cond_decode
+        )
+        with cond:
+            while not self._instances[itype]:
+                cond.wait()
+            http, zmq_ = random.choice(list(self._instances[itype].items()))
+            return http, zmq_
+
+    def size(self, itype: InstanceType) -> int:
+        with self._lock:
+            return len(self._instances[itype])
+
+
+# ─── ZMQ listener -------------------------------------------------------------
-def _listen_for_register(poller, router_socket):
+def _listen_for_register(registry: ServiceRegistry, router_socket, poller) -> None:
    while True:
        socks = dict(poller.poll())
        if router_socket in socks:
-            remote_address, message = router_socket.recv_multipart()
-            # data: {"type": "P", "http_address": 
"ip:port", - # "zmq_address": "ip:port"} - data = msgpack.loads(message) - # print("Received message from %s, data: %s", - # remote_address.decode(), data) - if data["type"] == "P": - global prefill_instances - global prefill_cv - with prefill_cv: - prefill_instances[data["http_address"]] = data["zmq_address"] - elif data["type"] == "D": - global decode_instances - global decode_cv - with decode_cv: - decode_instances[data["http_address"]] = data["zmq_address"] - else: - print( - "Unexpected, Received message from %s, data: %s", - remote_address, - data, + _, message = router_socket.recv_multipart() + data = msgpack.loads( + message + ) # {"type":"P","http_address":..,"zmq_address":..} + try: + registry.register( + data["type"], data["http_address"], data["zmq_address"] ) + except (KeyError, ValueError): + print("⚠️ malformed registration data:", data) -def start_service_discovery(hostname, port): - if not hostname: - hostname = socket.gethostname() +def start_service_discovery( + registry: ServiceRegistry, hostname: str, port: int +) -> threading.Thread: if port == 0: raise ValueError("Port cannot be 0") + if not hostname: + hostname = socket.gethostname() - context = zmq.Context() - router_socket = context.socket(zmq.ROUTER) - router_socket.bind(f"tcp://{hostname}:{port}") + ctx, router = zmq.Context(), zmq.Context().socket(zmq.ROUTER) + router.bind(f"tcp://{hostname}:{port}") poller = zmq.Poller() - poller.register(router_socket, zmq.POLLIN) - - _listener_thread = threading.Thread( - target=_listen_for_register, args=[poller, router_socket], daemon=True + poller.register(router, zmq.POLLIN) + t = threading.Thread( + target=_listen_for_register, daemon=True, args=(registry, router, poller) ) - _listener_thread.start() - return _listener_thread + t.start() + return t +# ─── HTTP proxy server -------------------------------------------------------- AIOHTTP_TIMEOUT = aiohttp.ClientTimeout(total=6 * 60 * 60) - app = Quart(__name__) +registry = ServiceRegistry() # <── NEW single global -def random_uuid() -> str: - return str(uuid.uuid4().hex) +def _uuid() -> str: + return uuid.uuid4().hex -async def forward_request(url, data, request_id): - async with aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session: +async def forward_request(url: str, data: dict, request_id: str): + async with aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as sess: headers = { - "Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}", + "Authorization": f"Bearer {os.getenv('OPENAI_API_KEY', '')}", "X-Request-Id": request_id, } - async with session.post(url=url, json=data, headers=headers) as response: - if response.status == 200: - # if response.headers.get('Transfer-Encoding') == 'chunked': - if True: - async for chunk_bytes in response.content.iter_chunked(1024): - yield chunk_bytes - else: - content = await response.read() - yield content - - -@app.route("/v1/completions", methods=["POST"]) -async def handle_request(): - try: - original_request_data = await request.get_json() - - prefill_request = original_request_data.copy() - # change max_tokens = 1 to let it only do prefill - prefill_request["max_tokens"] = 1 - - global prefill_instances - global prefill_cv - with prefill_cv: - prefill_addr, prefill_zmq_addr = random.choice( - list(prefill_instances.items()) - ) - print( - "handle_request, prefill_addr: %s, zmq_addr: %s", - prefill_addr, - prefill_zmq_addr, - ) - - global decode_instances - global decode_cv - with decode_cv: - decode_addr, decode_zmq_addr = random.choice(list(decode_instances.items())) 
- print( - "handle_request, decode_addr: %s, zmq_addr: %s", - decode_addr, - decode_zmq_addr, - ) - print( - f"======== {prefill_zmq_addr} /v1/completions prefill_instances {prefill_instances} ========== ", - flush=True, - ) - print( - f"======== {decode_zmq_addr} /v1/completions decode_instances {decode_instances} ========== ", - flush=True, - ) - - request_id = f"___prefill_addr_{prefill_zmq_addr}___decode_addr_{decode_zmq_addr}_{random_uuid()}" - - # finish prefill - async for _ in forward_request( - f"http://{prefill_addr}/v1/completions", prefill_request, request_id - ): - continue - - # return decode - generator = forward_request( - f"http://{decode_addr}/v1/completions", original_request_data, request_id - ) - response = await make_response(generator) - response.timeout = None - - return response - - except Exception as e: - import sys - import traceback - - exc_info = sys.exc_info() - print("Error occurred in disagg prefill proxy server") - print(e) - print("".join(traceback.format_exception(*exc_info))) + async with sess.post(url, json=data, headers=headers) as resp: + if resp.status == 200: + async for chunk in resp.content.iter_chunked(1024): + yield chunk + else: + raise RuntimeError(f"Upstream {url} returned {resp.status}") -@app.route("/v1/chat/completions", methods=["POST"]) -async def handle_chat_request(): - try: - original_request_data = await request.get_json() +async def _handle_common(original_request: dict, api_path: str): + # pick instances + pre_http, pre_zmq = registry.random_instance(InstanceType.PREFILL) + dec_http, dec_zmq = registry.random_instance(InstanceType.DECODE) + request_id = f"___prefill_{pre_zmq}___decode_{dec_zmq}___{_uuid()}" - prefill_request = original_request_data.copy() - # change max_tokens = 1 to let it only do prefill - prefill_request["max_tokens"] = 1 + # 1️⃣ prefill: max_tokens = 1 + prefill_request = {**original_request, "max_tokens": 1} + async for _ in forward_request( + f"http://{pre_http}{api_path}", prefill_request, request_id + ): + continue - global prefill_instances - global prefill_cv - with prefill_cv: - prefill_addr, prefill_zmq_addr = random.choice( - list(prefill_instances.items()) - ) - print( - "handle_chat_request, prefill_addr: %s, zmq_addr: %s", - prefill_addr, - prefill_zmq_addr, - ) + # 2️⃣ decode: stream back to client + generator = forward_request( + f"http://{dec_http}{api_path}", original_request, request_id + ) + resp = await make_response(generator) + resp.timeout = None + return resp - global decode_instances - global decode_cv - with decode_cv: - decode_addr, decode_zmq_addr = random.choice(list(decode_instances.items())) - print( - "handle_chat_request, decode_addr: %s, zmq_addr: %s", - decode_addr, - decode_zmq_addr, - ) - print( - f"======== {prefill_zmq_addr} /v1/chat/completions prefill_instances {prefill_instances} ========== ", - flush=True, - ) - print( - f"======== {decode_zmq_addr} /v1/chat/completions decode_instances {decode_instances} ========== ", - flush=True, - ) - request_id = f"___prefill_addr_{prefill_zmq_addr}___decode_addr_{decode_zmq_addr}_{random_uuid()}" - - # finish prefill - async for _ in forward_request( - f"http://{prefill_addr}/v1/chat/completions", prefill_request, request_id - ): - continue - - # return decode - generator = forward_request( - f"http://{decode_addr}/v1/chat/completions", - original_request_data, - request_id, - ) - response = await make_response(generator) - response.timeout = None - return response +@app.post("/v1/completions") +async def handle_request(): # 
legacy openai completions + return await _handle_common(await request.get_json(), "/v1/completions") - except Exception as e: - import sys - import traceback - exc_info = sys.exc_info() - print("Error occurred in disagg prefill proxy server") - print(e) - print("".join(traceback.format_exception(*exc_info))) +@app.post("/v1/chat/completions") +async def handle_chat_request(): # chat completions + return await _handle_common(await request.get_json(), "/v1/chat/completions") +# ─── main ───────────────────────────────────────────────────────────────────── if __name__ == "__main__": - t = start_service_discovery("0.0.0.0", 30001) - app.run(host="0.0.0.0", port=10001) - t.join() + discovery_thread = start_service_discovery(registry, "0.0.0.0", 30001) + try: + # Quart uses asyncio, so run() is non‑blocking once the event loop starts. + app.run(host="0.0.0.0", port=10001) + finally: + discovery_thread.join() From f3e00a6220d740512b7818c1bd4ea6bbd77987a8 Mon Sep 17 00:00:00 2001 From: chenzhuo Date: Thu, 17 Apr 2025 16:03:13 +0800 Subject: [PATCH 25/62] fix code --- flagscale/serve/run_pd_disagg_router.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/flagscale/serve/run_pd_disagg_router.py b/flagscale/serve/run_pd_disagg_router.py index c27a44067..cb93448ea 100644 --- a/flagscale/serve/run_pd_disagg_router.py +++ b/flagscale/serve/run_pd_disagg_router.py @@ -60,6 +60,10 @@ def random_instance(self, itype: InstanceType) -> Tuple[str, str]: while not self._instances[itype]: cond.wait() http, zmq_ = random.choice(list(self._instances[itype].items())) + print( + f"============== instance type: {itype} self._instances {self._instances} =================", + flush=True, + ) return http, zmq_ def size(self, itype: InstanceType) -> int: From 19ba091bab1d6376ec20d695bacc2c13600ff225 Mon Sep 17 00:00:00 2001 From: chenzhuo Date: Thu, 17 Apr 2025 16:32:05 +0800 Subject: [PATCH 26/62] fix code --- flagscale/serve/run_pd_disagg_router.py | 307 ++++++++++++++---------- 1 file changed, 181 insertions(+), 126 deletions(-) diff --git a/flagscale/serve/run_pd_disagg_router.py b/flagscale/serve/run_pd_disagg_router.py index cb93448ea..54ff826db 100644 --- a/flagscale/serve/run_pd_disagg_router.py +++ b/flagscale/serve/run_pd_disagg_router.py @@ -1,174 +1,229 @@ -from __future__ import annotations - -import asyncio import os import random import socket import threading import uuid -from enum import Enum -from typing import Dict, Literal, Optional, Tuple import aiohttp import msgpack import zmq from quart import Quart, make_response, request +prefill_instances: dict[str, str] = {} # http_address: zmq_address +decode_instances: dict[str, str] = {} # http_address: zmq_address -# ─── Registry ────────────────────────────────────────────────────────────────── -class InstanceType(Enum): - PREFILL = "P" - DECODE = "D" - +prefill_cv = threading.Condition() +decode_cv = threading.Condition() -class ServiceRegistry: - """Thread‑safe container for prefill / decode instance metadata.""" - - def __init__(self) -> None: - self._lock = threading.RLock() - self._cond_prefill = threading.Condition(self._lock) - self._cond_decode = threading.Condition(self._lock) - self._instances: Dict[InstanceType, Dict[str, str]] = { - InstanceType.PREFILL: {}, - InstanceType.DECODE: {}, - } - # ---- public API ---------------------------------------------------------- - def register( - self, - itype: InstanceType | Literal["P", "D"], - http_addr: str, - zmq_addr: str, - ) -> None: - itype = InstanceType(itype) # cast if literal - with 
self._lock: - self._instances[itype][http_addr] = zmq_addr - # wake one waiter if someone is blocked on this pool - cond = ( - self._cond_prefill - if itype is InstanceType.PREFILL - else self._cond_decode - ) - cond.notify() - - def random_instance(self, itype: InstanceType) -> Tuple[str, str]: - """Blocks until at least one instance of *itype* is present.""" - cond = ( - self._cond_prefill if itype is InstanceType.PREFILL else self._cond_decode - ) - with cond: - while not self._instances[itype]: - cond.wait() - http, zmq_ = random.choice(list(self._instances[itype].items())) - print( - f"============== instance type: {itype} self._instances {self._instances} =================", - flush=True, - ) - return http, zmq_ - - def size(self, itype: InstanceType) -> int: - with self._lock: - return len(self._instances[itype]) - - -# ─── ZMQ listener ------------------------------------------------------------- -def _listen_for_register(registry: ServiceRegistry, router_socket, poller) -> None: +def _listen_for_register(poller, router_socket): while True: socks = dict(poller.poll()) if router_socket in socks: - _, message = router_socket.recv_multipart() - data = msgpack.loads( - message - ) # {"type":"P","http_address":..,"zmq_address":..} - try: - registry.register( - data["type"], data["http_address"], data["zmq_address"] + remote_address, message = router_socket.recv_multipart() + # data: {"type": "P", "http_address": "ip:port", + # "zmq_address": "ip:port"} + data = msgpack.loads(message) + # print("Received message from %s, data: %s", + # remote_address.decode(), data) + if data["type"] == "P": + global prefill_instances + global prefill_cv + with prefill_cv: + prefill_instances[data["http_address"]] = data["zmq_address"] + elif data["type"] == "D": + global decode_instances + global decode_cv + with decode_cv: + decode_instances[data["http_address"]] = data["zmq_address"] + else: + print( + "Unexpected, Received message from %s, data: %s", + remote_address, + data, ) - except (KeyError, ValueError): - print("⚠️ malformed registration data:", data) -def start_service_discovery( - registry: ServiceRegistry, hostname: str, port: int -) -> threading.Thread: - if port == 0: - raise ValueError("Port cannot be 0") +def start_service_discovery(hostname, port): if not hostname: hostname = socket.gethostname() + if port == 0: + raise ValueError("Port cannot be 0") - ctx, router = zmq.Context(), zmq.Context().socket(zmq.ROUTER) - router.bind(f"tcp://{hostname}:{port}") + context = zmq.Context() + router_socket = context.socket(zmq.ROUTER) + router_socket.bind(f"tcp://{hostname}:{port}") poller = zmq.Poller() - poller.register(router, zmq.POLLIN) - t = threading.Thread( - target=_listen_for_register, daemon=True, args=(registry, router, poller) + poller.register(router_socket, zmq.POLLIN) + + _listener_thread = threading.Thread( + target=_listen_for_register, args=[poller, router_socket], daemon=True ) - t.start() - return t + _listener_thread.start() + return _listener_thread -# ─── HTTP proxy server -------------------------------------------------------- AIOHTTP_TIMEOUT = aiohttp.ClientTimeout(total=6 * 60 * 60) + app = Quart(__name__) -registry = ServiceRegistry() # <── NEW single global -def _uuid() -> str: - return uuid.uuid4().hex +def random_uuid() -> str: + return str(uuid.uuid4().hex) -async def forward_request(url: str, data: dict, request_id: str): - async with aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as sess: +async def forward_request(url, data, request_id): + async with 
aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session: headers = { - "Authorization": f"Bearer {os.getenv('OPENAI_API_KEY', '')}", + "Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}", "X-Request-Id": request_id, } - async with sess.post(url, json=data, headers=headers) as resp: - if resp.status == 200: - async for chunk in resp.content.iter_chunked(1024): - yield chunk - else: - raise RuntimeError(f"Upstream {url} returned {resp.status}") + async with session.post(url=url, json=data, headers=headers) as response: + if response.status == 200: + # if response.headers.get('Transfer-Encoding') == 'chunked': + if True: + async for chunk_bytes in response.content.iter_chunked(1024): + yield chunk_bytes + else: + content = await response.read() + yield content + + +@app.route("/v1/completions", methods=["POST"]) +async def handle_request(): + try: + original_request_data = await request.get_json() + prefill_request = original_request_data.copy() + # change max_tokens = 1 to let it only do prefill + prefill_request["max_tokens"] = 1 -async def _handle_common(original_request: dict, api_path: str): - # pick instances - pre_http, pre_zmq = registry.random_instance(InstanceType.PREFILL) - dec_http, dec_zmq = registry.random_instance(InstanceType.DECODE) - request_id = f"___prefill_{pre_zmq}___decode_{dec_zmq}___{_uuid()}" + global prefill_instances + global prefill_cv + with prefill_cv: + prefill_addr, prefill_zmq_addr = random.choice( + list(prefill_instances.items()) + ) + print( + "handle_request, prefill_addr: %s, zmq_addr: %s", + prefill_addr, + prefill_zmq_addr, + ) - # 1️⃣ prefill: max_tokens = 1 - prefill_request = {**original_request, "max_tokens": 1} - async for _ in forward_request( - f"http://{pre_http}{api_path}", prefill_request, request_id - ): - continue + global decode_instances + global decode_cv + with decode_cv: + decode_addr, decode_zmq_addr = random.choice(list(decode_instances.items())) + print( + "handle_request, decode_addr: %s, zmq_addr: %s", + decode_addr, + decode_zmq_addr, + ) + print( + f"======== {prefill_zmq_addr} /v1/completions prefill_instances {prefill_instances} ========== ", + flush=True, + ) + print( + f"======== {decode_zmq_addr} /v1/completions decode_instances {decode_instances} ========== ", + flush=True, + ) - # 2️⃣ decode: stream back to client - generator = forward_request( - f"http://{dec_http}{api_path}", original_request, request_id - ) - resp = await make_response(generator) - resp.timeout = None - return resp + request_id = f"___prefill_addr_{prefill_zmq_addr}___decode_addr_{decode_zmq_addr}_{random_uuid()}" + # finish prefill + async for _ in forward_request( + f"http://{prefill_addr}/v1/completions", prefill_request, request_id + ): + continue -@app.post("/v1/completions") -async def handle_request(): # legacy openai completions - return await _handle_common(await request.get_json(), "/v1/completions") + # return decode + generator = forward_request( + f"http://{decode_addr}/v1/completions", original_request_data, request_id + ) + response = await make_response(generator) + response.timeout = None + return response -@app.post("/v1/chat/completions") -async def handle_chat_request(): # chat completions - return await _handle_common(await request.get_json(), "/v1/chat/completions") + except Exception as e: + import sys + import traceback + exc_info = sys.exc_info() + print("Error occurred in disagg prefill proxy server") + print(e) + print("".join(traceback.format_exception(*exc_info))) -# ─── main 
───────────────────────────────────────────────────────────────────── -if __name__ == "__main__": - discovery_thread = start_service_discovery(registry, "0.0.0.0", 30001) + +@app.route("/v1/chat/completions", methods=["POST"]) +async def handle_chat_request(): try: - # Quart uses asyncio, so run() is non‑blocking once the event loop starts. - app.run(host="0.0.0.0", port=10001) - finally: - discovery_thread.join() + original_request_data = await request.get_json() + + prefill_request = original_request_data.copy() + # change max_tokens = 1 to let it only do prefill + prefill_request["max_tokens"] = 1 + + global prefill_instances + global prefill_cv + with prefill_cv: + prefill_addr, prefill_zmq_addr = random.choice( + list(prefill_instances.items()) + ) + print( + "handle_chat_request, prefill_addr: %s, zmq_addr: %s", + prefill_addr, + prefill_zmq_addr, + ) + + global decode_instances + global decode_cv + with decode_cv: + decode_addr, decode_zmq_addr = random.choice(list(decode_instances.items())) + print( + "handle_chat_request, decode_addr: %s, zmq_addr: %s", + decode_addr, + decode_zmq_addr, + ) + print( + f"======== {prefill_zmq_addr} /v1/chat/completions prefill_instances {prefill_instances} ========== ", + flush=True, + ) + print( + f"======== {decode_zmq_addr} /v1/chat/completions decode_instances {decode_instances} ========== ", + flush=True, + ) + request_id = f"___prefill_addr_{prefill_zmq_addr}___decode_addr_{decode_zmq_addr}_{random_uuid()}" + + # finish prefill + async for _ in forward_request( + f"http://{prefill_addr}/v1/chat/completions", prefill_request, request_id + ): + continue + + # return decode + generator = forward_request( + f"http://{decode_addr}/v1/chat/completions", + original_request_data, + request_id, + ) + response = await make_response(generator) + response.timeout = None + + return response + + except Exception as e: + import sys + import traceback + + exc_info = sys.exc_info() + print("Error occurred in disagg prefill proxy server") + print(e) + print("".join(traceback.format_exception(*exc_info))) + + +if __name__ == "__main__": + t = start_service_discovery("0.0.0.0", 30001) + app.run(host="0.0.0.0", port=10001) + t.join() From 2e8708f50f09ac9a0c34ca43042918ddb79db3e7 Mon Sep 17 00:00:00 2001 From: chenzhuo Date: Thu, 17 Apr 2025 17:20:19 +0800 Subject: [PATCH 27/62] fix code --- flagscale/serve/run_pd_disagg_router.py | 287 +++++++++++------------- 1 file changed, 129 insertions(+), 158 deletions(-) diff --git a/flagscale/serve/run_pd_disagg_router.py b/flagscale/serve/run_pd_disagg_router.py index 54ff826db..9ca459ce0 100644 --- a/flagscale/serve/run_pd_disagg_router.py +++ b/flagscale/serve/run_pd_disagg_router.py @@ -1,3 +1,4 @@ +import asyncio import os import random import socket @@ -9,8 +10,44 @@ import zmq from quart import Quart, make_response, request -prefill_instances: dict[str, str] = {} # http_address: zmq_address -decode_instances: dict[str, str] = {} # http_address: zmq_address + +# ─── Load Manager ──────────────────────────────────────────────────────────── +class LoadManager: + """Track number of in-flight tasks per instance and pick the least-loaded one.""" + + def __init__(self): + self._lock = threading.Lock() + self._load: dict[str, int] = {} + + def register(self, addr: str): + """Ensure this addr is known, with zero initial load.""" + with self._lock: + self._load.setdefault(addr, 0) + + def acquire(self) -> str: + """Pick the addr with minimal load and bump its count.""" + with self._lock: + if not self._load: + raise 
RuntimeError("No instances registered") + # find instance with smallest load + addr = min(self._load.items(), key=lambda kv: kv[1])[0] + self._load[addr] += 1 + return addr + + def release(self, addr: str): + """Decrement the load count for this addr.""" + with self._lock: + if addr in self._load and self._load[addr] > 0: + self._load[addr] -= 1 + + +# managers for prefill and decode pools +prefill_load_manager = LoadManager() +decode_load_manager = LoadManager() + +# ─── Service Discovery ─────────────────────────────────────────────────────── +prefill_instances: dict[str, str] = {} # http_address -> zmq_address +decode_instances: dict[str, str] = {} # http_address -> zmq_address prefill_cv = threading.Condition() decode_cv = threading.Condition() @@ -21,27 +58,19 @@ def _listen_for_register(poller, router_socket): socks = dict(poller.poll()) if router_socket in socks: remote_address, message = router_socket.recv_multipart() - # data: {"type": "P", "http_address": "ip:port", - # "zmq_address": "ip:port"} data = msgpack.loads(message) - # print("Received message from %s, data: %s", - # remote_address.decode(), data) + addr = data["http_address"] + zmq_addr = data["zmq_address"] if data["type"] == "P": - global prefill_instances - global prefill_cv with prefill_cv: - prefill_instances[data["http_address"]] = data["zmq_address"] + prefill_instances[addr] = zmq_addr + prefill_load_manager.register(addr) elif data["type"] == "D": - global decode_instances - global decode_cv with decode_cv: - decode_instances[data["http_address"]] = data["zmq_address"] + decode_instances[addr] = zmq_addr + decode_load_manager.register(addr) else: - print( - "Unexpected, Received message from %s, data: %s", - remote_address, - data, - ) + print(f"Unexpected message type from {remote_address}: {data}") def start_service_discovery(hostname, port): @@ -57,20 +86,20 @@ def start_service_discovery(hostname, port): poller = zmq.Poller() poller.register(router_socket, zmq.POLLIN) - _listener_thread = threading.Thread( + listener = threading.Thread( target=_listen_for_register, args=[poller, router_socket], daemon=True ) - _listener_thread.start() - return _listener_thread + listener.start() + return listener +# ─── HTTP Proxy ─────────────────────────────────────────────────────────────── AIOHTTP_TIMEOUT = aiohttp.ClientTimeout(total=6 * 60 * 60) - app = Quart(__name__) def random_uuid() -> str: - return str(uuid.uuid4().hex) + return uuid.uuid4().hex async def forward_request(url, data, request_id): @@ -79,151 +108,93 @@ async def forward_request(url, data, request_id): "Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}", "X-Request-Id": request_id, } - async with session.post(url=url, json=data, headers=headers) as response: - if response.status == 200: - # if response.headers.get('Transfer-Encoding') == 'chunked': - if True: - async for chunk_bytes in response.content.iter_chunked(1024): - yield chunk_bytes - else: - content = await response.read() - yield content + async with session.post(url=url, json=data, headers=headers) as resp: + resp.raise_for_status() + async for chunk in resp.content.iter_chunked(1024): + yield chunk + + +async def _stream_and_release(gen, manager: LoadManager, addr: str): + try: + async for chunk in gen: + yield chunk + finally: + manager.release(addr) @app.route("/v1/completions", methods=["POST"]) async def handle_request(): - try: - original_request_data = await request.get_json() - - prefill_request = original_request_data.copy() - # change max_tokens = 1 to let it only do 
prefill - prefill_request["max_tokens"] = 1 - - global prefill_instances - global prefill_cv - with prefill_cv: - prefill_addr, prefill_zmq_addr = random.choice( - list(prefill_instances.items()) - ) - print( - "handle_request, prefill_addr: %s, zmq_addr: %s", - prefill_addr, - prefill_zmq_addr, - ) - - global decode_instances - global decode_cv - with decode_cv: - decode_addr, decode_zmq_addr = random.choice(list(decode_instances.items())) - print( - "handle_request, decode_addr: %s, zmq_addr: %s", - decode_addr, - decode_zmq_addr, - ) - print( - f"======== {prefill_zmq_addr} /v1/completions prefill_instances {prefill_instances} ========== ", - flush=True, - ) - print( - f"======== {decode_zmq_addr} /v1/completions decode_instances {decode_instances} ========== ", - flush=True, - ) - - request_id = f"___prefill_addr_{prefill_zmq_addr}___decode_addr_{decode_zmq_addr}_{random_uuid()}" - - # finish prefill - async for _ in forward_request( - f"http://{prefill_addr}/v1/completions", prefill_request, request_id - ): - continue - - # return decode - generator = forward_request( - f"http://{decode_addr}/v1/completions", original_request_data, request_id - ) - response = await make_response(generator) - response.timeout = None - - return response - - except Exception as e: - import sys - import traceback - - exc_info = sys.exc_info() - print("Error occurred in disagg prefill proxy server") - print(e) - print("".join(traceback.format_exception(*exc_info))) + original = await request.get_json() + prefill_data = original.copy() + prefill_data["max_tokens"] = 1 + + # pick least-loaded prefill + with prefill_cv: + prefill_addr = prefill_load_manager.acquire() + prefill_zmq = prefill_instances[prefill_addr] + print(f"Selected prefill {prefill_addr} (load bumped)") + + # finish prefill stage + prefill_req_id = f"pre_{random_uuid()}" + async for _ in forward_request( + f"http://{prefill_addr}/v1/completions", prefill_data, prefill_req_id + ): + pass + # release prefill slot + prefill_load_manager.release(prefill_addr) + + # pick least-loaded decode + with decode_cv: + decode_addr = decode_load_manager.acquire() + decode_zmq = decode_instances[decode_addr] + print(f"Selected decode {decode_addr} (load bumped)") + + # stream decode back to client, releasing when done + decode_req_id = f"dec_{random_uuid()}" + decoder = forward_request( + f"http://{decode_addr}/v1/completions", original, decode_req_id + ) + wrapped = _stream_and_release(decoder, decode_load_manager, decode_addr) + response = await make_response(wrapped) + response.timeout = None + return response @app.route("/v1/chat/completions", methods=["POST"]) async def handle_chat_request(): - try: - original_request_data = await request.get_json() - - prefill_request = original_request_data.copy() - # change max_tokens = 1 to let it only do prefill - prefill_request["max_tokens"] = 1 - - global prefill_instances - global prefill_cv - with prefill_cv: - prefill_addr, prefill_zmq_addr = random.choice( - list(prefill_instances.items()) - ) - print( - "handle_chat_request, prefill_addr: %s, zmq_addr: %s", - prefill_addr, - prefill_zmq_addr, - ) - - global decode_instances - global decode_cv - with decode_cv: - decode_addr, decode_zmq_addr = random.choice(list(decode_instances.items())) - print( - "handle_chat_request, decode_addr: %s, zmq_addr: %s", - decode_addr, - decode_zmq_addr, - ) - print( - f"======== {prefill_zmq_addr} /v1/chat/completions prefill_instances {prefill_instances} ========== ", - flush=True, - ) - print( - f"======== 
{decode_zmq_addr} /v1/chat/completions decode_instances {decode_instances} ========== ", - flush=True, - ) - request_id = f"___prefill_addr_{prefill_zmq_addr}___decode_addr_{decode_zmq_addr}_{random_uuid()}" - - # finish prefill - async for _ in forward_request( - f"http://{prefill_addr}/v1/chat/completions", prefill_request, request_id - ): - continue - - # return decode - generator = forward_request( - f"http://{decode_addr}/v1/chat/completions", - original_request_data, - request_id, - ) - response = await make_response(generator) - response.timeout = None - - return response - - except Exception as e: - import sys - import traceback - - exc_info = sys.exc_info() - print("Error occurred in disagg prefill proxy server") - print(e) - print("".join(traceback.format_exception(*exc_info))) + original = await request.get_json() + prefill_data = original.copy() + prefill_data["max_tokens"] = 1 + + with prefill_cv: + prefill_addr = prefill_load_manager.acquire() + prefill_zmq = prefill_instances[prefill_addr] + print(f"Selected prefill(chat) {prefill_addr}") + + prefill_req_id = f"pre_chat_{random_uuid()}" + async for _ in forward_request( + f"http://{prefill_addr}/v1/chat/completions", + prefill_data, + prefill_req_id, + ): + pass + prefill_load_manager.release(prefill_addr) + + with decode_cv: + decode_addr = decode_load_manager.acquire() + decode_zmq = decode_instances[decode_addr] + print(f"Selected decode(chat) {decode_addr}") + + decode_req_id = f"dec_chat_{random_uuid()}" + decoder = forward_request( + f"http://{decode_addr}/v1/chat/completions", original, decode_req_id + ) + wrapped = _stream_and_release(decoder, decode_load_manager, decode_addr) + response = await make_response(wrapped) + response.timeout = None + return response if __name__ == "__main__": - t = start_service_discovery("0.0.0.0", 30001) + start_service_discovery("0.0.0.0", 30001) app.run(host="0.0.0.0", port=10001) - t.join() From 0bc464a77179e711500b3d72c3145e4f3979b1bc Mon Sep 17 00:00:00 2001 From: chenzhuo Date: Thu, 17 Apr 2025 19:09:29 +0800 Subject: [PATCH 28/62] fix code --- flagscale/serve/run_pd_disagg_router.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/flagscale/serve/run_pd_disagg_router.py b/flagscale/serve/run_pd_disagg_router.py index 9ca459ce0..b96417108 100644 --- a/flagscale/serve/run_pd_disagg_router.py +++ b/flagscale/serve/run_pd_disagg_router.py @@ -23,9 +23,17 @@ def register(self, addr: str): """Ensure this addr is known, with zero initial load.""" with self._lock: self._load.setdefault(addr, 0) + print( + f"register-------------- self._load {self._load} -----------------", + flush=True, + ) def acquire(self) -> str: """Pick the addr with minimal load and bump its count.""" + print( + f"acquire-------------- self._load {self._load} -----------------", + flush=True, + ) with self._lock: if not self._load: raise RuntimeError("No instances registered") From abb6e930c40acd1852d9aba02cf28a5e4117c74b Mon Sep 17 00:00:00 2001 From: chenzhuo Date: Fri, 18 Apr 2025 00:05:39 +0800 Subject: [PATCH 29/62] fix code --- examples/qwen/conf/target.sh | 2 +- flagscale/serve/run_pd_disagg_router.py | 6 +++++- 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/examples/qwen/conf/target.sh b/examples/qwen/conf/target.sh index abbc609a2..aef7b1dad 100644 --- a/examples/qwen/conf/target.sh +++ b/examples/qwen/conf/target.sh @@ -25,7 +25,7 @@ pkill -f 'vllm serve' source /root/miniconda3/bin/activate flagscale-inference && export GLOO_SOCKET_IFNAME=bond0 && ${ray_path} start --head 
--port=59081 --num-gpus=8 # worker nodes -ssh -n -p 22 10.1.1.108 "docker exec ds /bin/bash -c 'source /root/miniconda3/bin/activate flagscale-inference && export GLOO_SOCKET_IFNAME=bond0 && ${ray_path} start --address=10.1.1.122:59081 --num-gpus=8'" +ssh -n -p 22 10.1.1.108 "docker exec ds /bin/bash -c 'source /root/miniconda3/bin/activate flagscale-inference && export CUDA_VISIBLE_DEVICES=7 && ${ray_path} start --address=10.1.1.122:59081 --num-gpus=8'" mkdir -p /mine/ip122/tune_qwen/github_flagscale/outputs/deepseek_v3/serve_logs mkdir -p /mine/ip122/tune_qwen/github_flagscale/outputs/deepseek_v3/serve_logs/pids diff --git a/flagscale/serve/run_pd_disagg_router.py b/flagscale/serve/run_pd_disagg_router.py index b96417108..4ff4b8010 100644 --- a/flagscale/serve/run_pd_disagg_router.py +++ b/flagscale/serve/run_pd_disagg_router.py @@ -22,7 +22,8 @@ def __init__(self): def register(self, addr: str): """Ensure this addr is known, with zero initial load.""" with self._lock: - self._load.setdefault(addr, 0) + if addr not in self._load: + self._load[addr] = 0 print( f"register-------------- self._load {self._load} -----------------", flush=True, @@ -73,10 +74,13 @@ def _listen_for_register(poller, router_socket): with prefill_cv: prefill_instances[addr] = zmq_addr prefill_load_manager.register(addr) + print(f"[SERVICE DISCOVERY][PREFILL] registered {addr}") + elif data["type"] == "D": with decode_cv: decode_instances[addr] = zmq_addr decode_load_manager.register(addr) + print(f"[SERVICE DISCOVERY][DECODE ] registered {addr}") else: print(f"Unexpected message type from {remote_address}: {data}") From 28d385d9f32e950a24d3a3c936dc5904a782bd88 Mon Sep 17 00:00:00 2001 From: chenzhuo Date: Fri, 18 Apr 2025 00:35:34 +0800 Subject: [PATCH 30/62] v1 --- flagscale/serve/run_pd_disagg_router.py | 216 ++++++++---------------- 1 file changed, 69 insertions(+), 147 deletions(-) diff --git a/flagscale/serve/run_pd_disagg_router.py b/flagscale/serve/run_pd_disagg_router.py index 4ff4b8010..75d405638 100644 --- a/flagscale/serve/run_pd_disagg_router.py +++ b/flagscale/serve/run_pd_disagg_router.py @@ -1,4 +1,3 @@ -import asyncio import os import random import socket @@ -11,81 +10,50 @@ from quart import Quart, make_response, request -# ─── Load Manager ──────────────────────────────────────────────────────────── -class LoadManager: - """Track number of in-flight tasks per instance and pick the least-loaded one.""" - +class ResourceManager: + """Thread-safe manager for prefill and decode instances.""" def __init__(self): self._lock = threading.Lock() - self._load: dict[str, int] = {} + self._instances = { + "P": {}, # type: http_address -> zmq_address + "D": {}, # decode: http_address -> zmq_address + } + self._conds = { + "P": threading.Condition(self._lock), + "D": threading.Condition(self._lock), + } - def register(self, addr: str): - """Ensure this addr is known, with zero initial load.""" - with self._lock: - if addr not in self._load: - self._load[addr] = 0 - print( - f"register-------------- self._load {self._load} -----------------", - flush=True, - ) - - def acquire(self) -> str: - """Pick the addr with minimal load and bump its count.""" - print( - f"acquire-------------- self._load {self._load} -----------------", - flush=True, - ) - with self._lock: - if not self._load: - raise RuntimeError("No instances registered") - # find instance with smallest load - addr = min(self._load.items(), key=lambda kv: kv[1])[0] - self._load[addr] += 1 - return addr - - def release(self, addr: str): - """Decrement the 
load count for this addr.""" + def register(self, itype: str, http_addr: str, zmq_addr: str): + """Register a new instance of type P or D.""" with self._lock: - if addr in self._load and self._load[addr] > 0: - self._load[addr] -= 1 + self._instances[itype][http_addr] = zmq_addr + self._conds[itype].notify_all() + def get_random(self, itype: str) -> tuple[str, str]: + """Get a random available instance, blocking until one is available.""" + cond = self._conds[itype] + with cond: + while not self._instances[itype]: + cond.wait() + items = list(self._instances[itype].items()) + http_addr, zmq_addr = random.choice(items) + return http_addr, zmq_addr -# managers for prefill and decode pools -prefill_load_manager = LoadManager() -decode_load_manager = LoadManager() -# ─── Service Discovery ─────────────────────────────────────────────────────── -prefill_instances: dict[str, str] = {} # http_address -> zmq_address -decode_instances: dict[str, str] = {} # http_address -> zmq_address - -prefill_cv = threading.Condition() -decode_cv = threading.Condition() - - -def _listen_for_register(poller, router_socket): +def _listen_for_register(poller, router_socket, manager: ResourceManager): while True: socks = dict(poller.poll()) if router_socket in socks: remote_address, message = router_socket.recv_multipart() data = msgpack.loads(message) - addr = data["http_address"] - zmq_addr = data["zmq_address"] - if data["type"] == "P": - with prefill_cv: - prefill_instances[addr] = zmq_addr - prefill_load_manager.register(addr) - print(f"[SERVICE DISCOVERY][PREFILL] registered {addr}") - - elif data["type"] == "D": - with decode_cv: - decode_instances[addr] = zmq_addr - decode_load_manager.register(addr) - print(f"[SERVICE DISCOVERY][DECODE ] registered {addr}") + itype = data.get("type") + if itype in ("P", "D"): + manager.register(itype, data["http_address"], data["zmq_address"]) else: - print(f"Unexpected message type from {remote_address}: {data}") + print(f"Unexpected message from {remote_address}: {data}") -def start_service_discovery(hostname, port): +def start_service_discovery(hostname: str, port: int, manager: ResourceManager) -> threading.Thread: if not hostname: hostname = socket.gethostname() if port == 0: @@ -99,16 +67,17 @@ def start_service_discovery(hostname, port): poller.register(router_socket, zmq.POLLIN) listener = threading.Thread( - target=_listen_for_register, args=[poller, router_socket], daemon=True + target=_listen_for_register, + args=(poller, router_socket, manager), + daemon=True ) listener.start() return listener -# ─── HTTP Proxy ─────────────────────────────────────────────────────────────── AIOHTTP_TIMEOUT = aiohttp.ClientTimeout(total=6 * 60 * 60) app = Quart(__name__) - +resource_manager = ResourceManager() def random_uuid() -> str: return uuid.uuid4().hex @@ -118,95 +87,48 @@ async def forward_request(url, data, request_id): async with aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session: headers = { "Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}", - "X-Request-Id": request_id, + "X-Request-Id": request_id } - async with session.post(url=url, json=data, headers=headers) as resp: - resp.raise_for_status() - async for chunk in resp.content.iter_chunked(1024): - yield chunk + async with session.post(url=url, json=data, headers=headers) as response: + if response.status == 200: + async for chunk in response.content.iter_chunked(1024): + yield chunk + else: + content = await response.read() + yield content -async def _stream_and_release(gen, manager: 
LoadManager, addr: str): +@app.route('/v1/completions', methods=['POST']) +async def handle_request(): try: - async for chunk in gen: - yield chunk - finally: - manager.release(addr) + original_data = await request.get_json() + prefill_data = original_data.copy() + prefill_data['max_tokens'] = 1 + prefill_addr, prefill_zmq = resource_manager.get_random("P") + decode_addr, decode_zmq = resource_manager.get_random("D") + print(f"handle_request, prefill: {prefill_addr}/{prefill_zmq}, decode: {decode_addr}/{decode_zmq}") -@app.route("/v1/completions", methods=["POST"]) -async def handle_request(): - original = await request.get_json() - prefill_data = original.copy() - prefill_data["max_tokens"] = 1 - - # pick least-loaded prefill - with prefill_cv: - prefill_addr = prefill_load_manager.acquire() - prefill_zmq = prefill_instances[prefill_addr] - print(f"Selected prefill {prefill_addr} (load bumped)") - - # finish prefill stage - prefill_req_id = f"pre_{random_uuid()}" - async for _ in forward_request( - f"http://{prefill_addr}/v1/completions", prefill_data, prefill_req_id - ): - pass - # release prefill slot - prefill_load_manager.release(prefill_addr) - - # pick least-loaded decode - with decode_cv: - decode_addr = decode_load_manager.acquire() - decode_zmq = decode_instances[decode_addr] - print(f"Selected decode {decode_addr} (load bumped)") - - # stream decode back to client, releasing when done - decode_req_id = f"dec_{random_uuid()}" - decoder = forward_request( - f"http://{decode_addr}/v1/completions", original, decode_req_id - ) - wrapped = _stream_and_release(decoder, decode_load_manager, decode_addr) - response = await make_response(wrapped) - response.timeout = None - return response - - -@app.route("/v1/chat/completions", methods=["POST"]) -async def handle_chat_request(): - original = await request.get_json() - prefill_data = original.copy() - prefill_data["max_tokens"] = 1 - - with prefill_cv: - prefill_addr = prefill_load_manager.acquire() - prefill_zmq = prefill_instances[prefill_addr] - print(f"Selected prefill(chat) {prefill_addr}") - - prefill_req_id = f"pre_chat_{random_uuid()}" - async for _ in forward_request( - f"http://{prefill_addr}/v1/chat/completions", - prefill_data, - prefill_req_id, - ): - pass - prefill_load_manager.release(prefill_addr) - - with decode_cv: - decode_addr = decode_load_manager.acquire() - decode_zmq = decode_instances[decode_addr] - print(f"Selected decode(chat) {decode_addr}") - - decode_req_id = f"dec_chat_{random_uuid()}" - decoder = forward_request( - f"http://{decode_addr}/v1/chat/completions", original, decode_req_id - ) - wrapped = _stream_and_release(decoder, decode_load_manager, decode_addr) - response = await make_response(wrapped) - response.timeout = None - return response + request_id = f"___prefill_addr_{prefill_zmq}___decode_addr_{decode_zmq}_{random_uuid()}" + + # Prefill stage + async for _ in forward_request( + f"http://{prefill_addr}/v1/completions", prefill_data, request_id): + pass + + # Decode stage and return + generator = forward_request( + f"http://{decode_addr}/v1/completions", original_data, request_id) + response = await make_response(generator) + response.timeout = None + return response + + except Exception as e: + import traceback + print("Error in disagg proxy:", e) + traceback.print_exc() if __name__ == "__main__": - start_service_discovery("0.0.0.0", 30001) - app.run(host="0.0.0.0", port=10001) + start_service_discovery("0.0.0.0", 30001, resource_manager) + app.run(host='0.0.0.0', port=10001) From 
f6ad7aa64280ba4b6c490f6e5d62e516954cfd28 Mon Sep 17 00:00:00 2001 From: chenzhuo Date: Fri, 18 Apr 2025 00:57:56 +0800 Subject: [PATCH 31/62] v2: dev robin --- flagscale/serve/run_pd_disagg_router.py | 180 ++++++++++++++++-------- 1 file changed, 119 insertions(+), 61 deletions(-) diff --git a/flagscale/serve/run_pd_disagg_router.py b/flagscale/serve/run_pd_disagg_router.py index 75d405638..b98598144 100644 --- a/flagscale/serve/run_pd_disagg_router.py +++ b/flagscale/serve/run_pd_disagg_router.py @@ -3,57 +3,95 @@ import socket import threading import uuid +import logging import aiohttp import msgpack import zmq from quart import Quart, make_response, request +logging.basicConfig(level=logging.INFO) +logger = logging.getLogger(__name__) +# ----------------------------------------------------------------------------- +# ResourceManager: track available instances and per-instance load counts +# ----------------------------------------------------------------------------- class ResourceManager: - """Thread-safe manager for prefill and decode instances.""" def __init__(self): self._lock = threading.Lock() - self._instances = { - "P": {}, # type: http_address -> zmq_address - "D": {}, # decode: http_address -> zmq_address - } - self._conds = { - "P": threading.Condition(self._lock), - "D": threading.Condition(self._lock), + # maps resource type 'P' or 'D' to dict of + # http_addr -> {'zmq': zmq_addr, 'load': int} + self._instances: dict[str, dict[str, dict[str, object]]] = { + 'P': {}, + 'D': {}, } - def register(self, itype: str, http_addr: str, zmq_addr: str): - """Register a new instance of type P or D.""" + def register(self, rtype: str, http_addr: str, zmq_addr: str): + with self._lock: + if http_addr not in self._instances[rtype]: + self._instances[rtype][http_addr] = {'zmq': zmq_addr, 'load': 0} + logger.info(f"Registered new {rtype}-instance {http_addr} (zmq={zmq_addr})") + else: + # update zmq address if it changed + self._instances[rtype][http_addr]['zmq'] = zmq_addr + + def increment_load(self, rtype: str, http_addr: str): + with self._lock: + self._instances[rtype][http_addr]['load'] += 1 + logger.debug(f"[{rtype}] +1 load on {http_addr}, now={self._instances[rtype][http_addr]['load']}") + + def decrement_load(self, rtype: str, http_addr: str): with self._lock: - self._instances[itype][http_addr] = zmq_addr - self._conds[itype].notify_all() + self._instances[rtype][http_addr]['load'] -= 1 + logger.debug(f"[{rtype}] -1 load on {http_addr}, now={self._instances[rtype][http_addr]['load']}") + + def get_random(self, rtype: str) -> tuple[str, str]: + with self._lock: + items = list(self._instances[rtype].items()) + http_addr, info = random.choice(items) + return http_addr, info['zmq'] + + def get_least_loaded(self, rtype: str) -> tuple[str, str]: + with self._lock: + # pick the instance with the smallest load + http_addr, info = min(self._instances[rtype].items(), key=lambda kv: kv[1]['load']) + return http_addr, info['zmq'] + +# ----------------------------------------------------------------------------- +# globals & startup +# ----------------------------------------------------------------------------- +rm = ResourceManager() - def get_random(self, itype: str) -> tuple[str, str]: - """Get a random available instance, blocking until one is available.""" - cond = self._conds[itype] - with cond: - while not self._instances[itype]: - cond.wait() - items = list(self._instances[itype].items()) - http_addr, zmq_addr = random.choice(items) - return http_addr, zmq_addr +# legacy dicts + 
conditions (still used for condition-waiting if desired) +prefill_instances: dict[str, str] = {} +decode_instances: dict[str, str] = {} +prefill_cv = threading.Condition() +decode_cv = threading.Condition() +# choose scheduling strategy via env var: "random" or "least" +SCHEDULING_STRATEGY = os.environ.get('SCHEDULING_STRATEGY', 'random').lower() -def _listen_for_register(poller, router_socket, manager: ResourceManager): +def _listen_for_register(poller, router_socket): while True: socks = dict(poller.poll()) if router_socket in socks: remote_address, message = router_socket.recv_multipart() data = msgpack.loads(message) - itype = data.get("type") - if itype in ("P", "D"): - manager.register(itype, data["http_address"], data["zmq_address"]) + typ = data.get("type") + http_addr = data.get("http_address") + zmq_addr = data.get("zmq_address") + if typ == "P": + with prefill_cv: + prefill_instances[http_addr] = zmq_addr + rm.register('P', http_addr, zmq_addr) + elif typ == "D": + with decode_cv: + decode_instances[http_addr] = zmq_addr + rm.register('D', http_addr, zmq_addr) else: - print(f"Unexpected message from {remote_address}: {data}") + logger.warning(f"Unexpected registration message: {data}") - -def start_service_discovery(hostname: str, port: int, manager: ResourceManager) -> threading.Thread: +def start_service_discovery(hostname, port): if not hostname: hostname = socket.gethostname() if port == 0: @@ -66,23 +104,23 @@ def start_service_discovery(hostname: str, port: int, manager: ResourceManager) poller = zmq.Poller() poller.register(router_socket, zmq.POLLIN) - listener = threading.Thread( + _listener_thread = threading.Thread( target=_listen_for_register, - args=(poller, router_socket, manager), + args=[poller, router_socket], daemon=True ) - listener.start() - return listener - + _listener_thread.start() + return _listener_thread +# ----------------------------------------------------------------------------- +# HTTP proxy logic +# ----------------------------------------------------------------------------- AIOHTTP_TIMEOUT = aiohttp.ClientTimeout(total=6 * 60 * 60) app = Quart(__name__) -resource_manager = ResourceManager() def random_uuid() -> str: return uuid.uuid4().hex - async def forward_request(url, data, request_id): async with aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session: headers = { @@ -91,44 +129,64 @@ async def forward_request(url, data, request_id): } async with session.post(url=url, json=data, headers=headers) as response: if response.status == 200: - async for chunk in response.content.iter_chunked(1024): - yield chunk + async for chunk_bytes in response.content.iter_chunked(1024): + yield chunk_bytes else: content = await response.read() yield content - @app.route('/v1/completions', methods=['POST']) async def handle_request(): try: - original_data = await request.get_json() - prefill_data = original_data.copy() - prefill_data['max_tokens'] = 1 - - prefill_addr, prefill_zmq = resource_manager.get_random("P") - decode_addr, decode_zmq = resource_manager.get_random("D") - print(f"handle_request, prefill: {prefill_addr}/{prefill_zmq}, decode: {decode_addr}/{decode_zmq}") - - request_id = f"___prefill_addr_{prefill_zmq}___decode_addr_{decode_zmq}_{random_uuid()}" - - # Prefill stage - async for _ in forward_request( - f"http://{prefill_addr}/v1/completions", prefill_data, request_id): - pass + original_request_data = await request.get_json() + prefill_request = original_request_data.copy() + prefill_request['max_tokens'] = 1 + + # choose prefill instance 
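+        # ('least' asks the ResourceManager for the instance with the smallest
+        # in-flight load counter; 'random' falls back to a uniform random pick)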
+ if SCHEDULING_STRATEGY == 'least': + prefill_addr, prefill_zmq = rm.get_least_loaded('P') + else: + prefill_addr, prefill_zmq = rm.get_random('P') + logger.info(f"Selected P-instance {prefill_addr} via '{SCHEDULING_STRATEGY}' strategy") + + # run prefill and track load + request_id = f"___prefill_{prefill_zmq}___decode___{random_uuid()}" + rm.increment_load('P', prefill_addr) + try: + async for _ in forward_request(f'http://{prefill_addr}/v1/completions', + prefill_request, request_id): + pass + finally: + rm.decrement_load('P', prefill_addr) + + # choose decode instance + if SCHEDULING_STRATEGY == 'least': + decode_addr, decode_zmq = rm.get_least_loaded('D') + else: + decode_addr, decode_zmq = rm.get_random('D') + logger.info(f"Selected D-instance {decode_addr} via '{SCHEDULING_STRATEGY}' strategy") + + # wrap decode generator to track load + async def tracked_decode(): + rm.increment_load('D', decode_addr) + try: + async for chunk in forward_request(f'http://{decode_addr}/v1/completions', + original_request_data, request_id): + yield chunk + finally: + rm.decrement_load('D', decode_addr) - # Decode stage and return - generator = forward_request( - f"http://{decode_addr}/v1/completions", original_data, request_id) + generator = tracked_decode() response = await make_response(generator) response.timeout = None return response except Exception as e: - import traceback - print("Error in disagg proxy:", e) - traceback.print_exc() - + import sys, traceback + logger.error("Error in proxy server", exc_info=e) + return {"error": str(e)}, 500 -if __name__ == "__main__": - start_service_discovery("0.0.0.0", 30001, resource_manager) +if __name__ == '__main__': + t = start_service_discovery("0.0.0.0", 30001) app.run(host='0.0.0.0', port=10001) + t.join() From 4f8025d8763469ac6f96828a176a0cadcbf3cb3f Mon Sep 17 00:00:00 2001 From: chenzhuo Date: Fri, 18 Apr 2025 01:12:50 +0800 Subject: [PATCH 32/62] v2: dev robin --- flagscale/serve/run_pd_disagg_router.py | 93 ++++++++++++++----------- 1 file changed, 52 insertions(+), 41 deletions(-) diff --git a/flagscale/serve/run_pd_disagg_router.py b/flagscale/serve/run_pd_disagg_router.py index b98598144..ebd1e28dd 100644 --- a/flagscale/serve/run_pd_disagg_router.py +++ b/flagscale/serve/run_pd_disagg_router.py @@ -10,17 +10,19 @@ import zmq from quart import Quart, make_response, request +# ----------------------------------------------------------------------------- +# 日志配置 +# ----------------------------------------------------------------------------- logging.basicConfig(level=logging.INFO) logger = logging.getLogger(__name__) # ----------------------------------------------------------------------------- -# ResourceManager: track available instances and per-instance load counts +# ResourceManager: 统一管理 P/D 实例及其负载 # ----------------------------------------------------------------------------- class ResourceManager: def __init__(self): self._lock = threading.Lock() - # maps resource type 'P' or 'D' to dict of - # http_addr -> {'zmq': zmq_addr, 'load': int} + # 每个资源类型 'P' 或 'D' 映射到 {http_addr: {'zmq': zmq_addr, 'load': int}} self._instances: dict[str, dict[str, dict[str, object]]] = { 'P': {}, 'D': {}, @@ -32,7 +34,7 @@ def register(self, rtype: str, http_addr: str, zmq_addr: str): self._instances[rtype][http_addr] = {'zmq': zmq_addr, 'load': 0} logger.info(f"Registered new {rtype}-instance {http_addr} (zmq={zmq_addr})") else: - # update zmq address if it changed + # 如果 zmq 地址更新,则同步 self._instances[rtype][http_addr]['zmq'] = zmq_addr def 
increment_load(self, rtype: str, http_addr: str): @@ -53,29 +55,32 @@ def get_random(self, rtype: str) -> tuple[str, str]: def get_least_loaded(self, rtype: str) -> tuple[str, str]: with self._lock: - # pick the instance with the smallest load - http_addr, info = min(self._instances[rtype].items(), key=lambda kv: kv[1]['load']) + http_addr, info = min(self._instances[rtype].items(), + key=lambda kv: kv[1]['load']) return http_addr, info['zmq'] # ----------------------------------------------------------------------------- -# globals & startup +# 全局对象与配置 # ----------------------------------------------------------------------------- rm = ResourceManager() -# legacy dicts + conditions (still used for condition-waiting if desired) +# 兼容旧版注册字典与 Condition,保留供外部等待 prefill_instances: dict[str, str] = {} decode_instances: dict[str, str] = {} prefill_cv = threading.Condition() decode_cv = threading.Condition() -# choose scheduling strategy via env var: "random" or "least" +# 调度策略:random 或 least(最少负载) SCHEDULING_STRATEGY = os.environ.get('SCHEDULING_STRATEGY', 'random').lower() +# ----------------------------------------------------------------------------- +# 服务发现:接收实例注册 +# ----------------------------------------------------------------------------- def _listen_for_register(poller, router_socket): while True: socks = dict(poller.poll()) if router_socket in socks: - remote_address, message = router_socket.recv_multipart() + remote_addr, message = router_socket.recv_multipart() data = msgpack.loads(message) typ = data.get("type") http_addr = data.get("http_address") @@ -104,16 +109,16 @@ def start_service_discovery(hostname, port): poller = zmq.Poller() poller.register(router_socket, zmq.POLLIN) - _listener_thread = threading.Thread( + listener = threading.Thread( target=_listen_for_register, args=[poller, router_socket], daemon=True ) - _listener_thread.start() - return _listener_thread + listener.start() + return listener # ----------------------------------------------------------------------------- -# HTTP proxy logic +# HTTP 代理与请求转发 # ----------------------------------------------------------------------------- AIOHTTP_TIMEOUT = aiohttp.ClientTimeout(total=6 * 60 * 60) app = Quart(__name__) @@ -127,30 +132,42 @@ async def forward_request(url, data, request_id): "Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}", "X-Request-Id": request_id } - async with session.post(url=url, json=data, headers=headers) as response: - if response.status == 200: - async for chunk_bytes in response.content.iter_chunked(1024): - yield chunk_bytes + async with session.post(url=url, json=data, headers=headers) as resp: + if resp.status == 200: + async for chunk in resp.content.iter_chunked(1024): + yield chunk else: - content = await response.read() + content = await resp.read() yield content @app.route('/v1/completions', methods=['POST']) async def handle_request(): try: - original_request_data = await request.get_json() - prefill_request = original_request_data.copy() + original_data = await request.get_json() + # 预填充请求:max_tokens=1 + prefill_request = original_data.copy() prefill_request['max_tokens'] = 1 - # choose prefill instance + # 选择 Prefill 实例 if SCHEDULING_STRATEGY == 'least': prefill_addr, prefill_zmq = rm.get_least_loaded('P') else: prefill_addr, prefill_zmq = rm.get_random('P') - logger.info(f"Selected P-instance {prefill_addr} via '{SCHEDULING_STRATEGY}' strategy") + logger.info(f"Selected P-instance {prefill_addr} via '{SCHEDULING_STRATEGY}'") + + # 选择 Decode 实例 + if SCHEDULING_STRATEGY == 
'least': + decode_addr, decode_zmq = rm.get_least_loaded('D') + else: + decode_addr, decode_zmq = rm.get_random('D') + logger.info(f"Selected D-instance {decode_addr} via '{SCHEDULING_STRATEGY}'") + + # 保持原始 request_id 组装格式 + request_id = ( + f"___prefill_addr_{prefill_zmq}___decode_addr_{decode_zmq}_{random_uuid()}" + ) - # run prefill and track load - request_id = f"___prefill_{prefill_zmq}___decode___{random_uuid()}" + # 执行 Prefill,并更新负载 rm.increment_load('P', prefill_addr) try: async for _ in forward_request(f'http://{prefill_addr}/v1/completions', @@ -159,34 +176,28 @@ async def handle_request(): finally: rm.decrement_load('P', prefill_addr) - # choose decode instance - if SCHEDULING_STRATEGY == 'least': - decode_addr, decode_zmq = rm.get_least_loaded('D') - else: - decode_addr, decode_zmq = rm.get_random('D') - logger.info(f"Selected D-instance {decode_addr} via '{SCHEDULING_STRATEGY}' strategy") - - # wrap decode generator to track load + # 执行 Decode,并更新负载 async def tracked_decode(): rm.increment_load('D', decode_addr) try: async for chunk in forward_request(f'http://{decode_addr}/v1/completions', - original_request_data, request_id): + original_data, request_id): yield chunk finally: rm.decrement_load('D', decode_addr) - generator = tracked_decode() - response = await make_response(generator) - response.timeout = None - return response + resp = await make_response(tracked_decode()) + resp.timeout = None + return resp except Exception as e: - import sys, traceback logger.error("Error in proxy server", exc_info=e) return {"error": str(e)}, 500 +# ----------------------------------------------------------------------------- +# 启动 +# ----------------------------------------------------------------------------- if __name__ == '__main__': - t = start_service_discovery("0.0.0.0", 30001) + listener = start_service_discovery("0.0.0.0", 30001) app.run(host='0.0.0.0', port=10001) - t.join() + listener.join() From 5b4622734b44d3a3953414205878dbadd4f7cf99 Mon Sep 17 00:00:00 2001 From: chenzhuo Date: Fri, 18 Apr 2025 01:20:13 +0800 Subject: [PATCH 33/62] v2: dev robin --- flagscale/serve/run_pd_disagg_router.py | 40 ++++++++++++++----------- 1 file changed, 22 insertions(+), 18 deletions(-) diff --git a/flagscale/serve/run_pd_disagg_router.py b/flagscale/serve/run_pd_disagg_router.py index ebd1e28dd..16826fd94 100644 --- a/flagscale/serve/run_pd_disagg_router.py +++ b/flagscale/serve/run_pd_disagg_router.py @@ -11,18 +11,18 @@ from quart import Quart, make_response, request # ----------------------------------------------------------------------------- -# 日志配置 +# Logging configuration # ----------------------------------------------------------------------------- logging.basicConfig(level=logging.INFO) logger = logging.getLogger(__name__) # ----------------------------------------------------------------------------- -# ResourceManager: 统一管理 P/D 实例及其负载 +# ResourceManager: unified management of P/D instances and their load # ----------------------------------------------------------------------------- class ResourceManager: def __init__(self): self._lock = threading.Lock() - # 每个资源类型 'P' 或 'D' 映射到 {http_addr: {'zmq': zmq_addr, 'load': int}} + # Each resource type 'P' or 'D' maps to {http_addr: {'zmq': zmq_addr, 'load': int}} self._instances: dict[str, dict[str, dict[str, object]]] = { 'P': {}, 'D': {}, @@ -34,7 +34,7 @@ def register(self, rtype: str, http_addr: str, zmq_addr: str): self._instances[rtype][http_addr] = {'zmq': zmq_addr, 'load': 0} logger.info(f"Registered new {rtype}-instance 
{http_addr} (zmq={zmq_addr})") else: - # 如果 zmq 地址更新,则同步 + # If zmq address changed, synchronize it self._instances[rtype][http_addr]['zmq'] = zmq_addr def increment_load(self, rtype: str, http_addr: str): @@ -60,21 +60,21 @@ def get_least_loaded(self, rtype: str) -> tuple[str, str]: return http_addr, info['zmq'] # ----------------------------------------------------------------------------- -# 全局对象与配置 +# Globals & configuration # ----------------------------------------------------------------------------- rm = ResourceManager() -# 兼容旧版注册字典与 Condition,保留供外部等待 +# Legacy registration dicts & Conditions retained for external waiting prefill_instances: dict[str, str] = {} decode_instances: dict[str, str] = {} prefill_cv = threading.Condition() decode_cv = threading.Condition() -# 调度策略:random 或 least(最少负载) +# Scheduling strategy: 'random' or 'least' (least load) SCHEDULING_STRATEGY = os.environ.get('SCHEDULING_STRATEGY', 'random').lower() # ----------------------------------------------------------------------------- -# 服务发现:接收实例注册 +# Service discovery: receive instance registrations # ----------------------------------------------------------------------------- def _listen_for_register(poller, router_socket): while True: @@ -118,7 +118,7 @@ def start_service_discovery(hostname, port): return listener # ----------------------------------------------------------------------------- -# HTTP 代理与请求转发 +# HTTP proxy & request forwarding # ----------------------------------------------------------------------------- AIOHTTP_TIMEOUT = aiohttp.ClientTimeout(total=6 * 60 * 60) app = Quart(__name__) @@ -140,47 +140,51 @@ async def forward_request(url, data, request_id): content = await resp.read() yield content +# support both /v1/completions and /v1/chat/completions @app.route('/v1/completions', methods=['POST']) +@app.route('/v1/chat/completions', methods=['POST']) async def handle_request(): try: original_data = await request.get_json() - # 预填充请求:max_tokens=1 + endpoint = request.path # this will be '/v1/completions' or '/v1/chat/completions' + + # Prefill request: max_tokens=1 prefill_request = original_data.copy() prefill_request['max_tokens'] = 1 - # 选择 Prefill 实例 + # Select Prefill instance if SCHEDULING_STRATEGY == 'least': prefill_addr, prefill_zmq = rm.get_least_loaded('P') else: prefill_addr, prefill_zmq = rm.get_random('P') logger.info(f"Selected P-instance {prefill_addr} via '{SCHEDULING_STRATEGY}'") - # 选择 Decode 实例 + # Select Decode instance if SCHEDULING_STRATEGY == 'least': decode_addr, decode_zmq = rm.get_least_loaded('D') else: decode_addr, decode_zmq = rm.get_random('D') logger.info(f"Selected D-instance {decode_addr} via '{SCHEDULING_STRATEGY}'") - # 保持原始 request_id 组装格式 + # Keep original request_id composition format request_id = ( f"___prefill_addr_{prefill_zmq}___decode_addr_{decode_zmq}_{random_uuid()}" ) - # 执行 Prefill,并更新负载 + # Execute Prefill and update load rm.increment_load('P', prefill_addr) try: - async for _ in forward_request(f'http://{prefill_addr}/v1/completions', + async for _ in forward_request(f'http://{prefill_addr}{endpoint}', prefill_request, request_id): pass finally: rm.decrement_load('P', prefill_addr) - # 执行 Decode,并更新负载 + # Execute Decode and update load async def tracked_decode(): rm.increment_load('D', decode_addr) try: - async for chunk in forward_request(f'http://{decode_addr}/v1/completions', + async for chunk in forward_request(f'http://{decode_addr}{endpoint}', original_data, request_id): yield chunk finally: @@ -195,7 +199,7 @@ async def 
tracked_decode(): return {"error": str(e)}, 500 # ----------------------------------------------------------------------------- -# 启动 +# Startup # ----------------------------------------------------------------------------- if __name__ == '__main__': listener = start_service_discovery("0.0.0.0", 30001) From 14dd136fd84cb6e91b71d13dfe9c176a0aac5067 Mon Sep 17 00:00:00 2001 From: chenzhuo Date: Fri, 18 Apr 2025 10:19:46 +0800 Subject: [PATCH 34/62] v2: dev load --- .../config_qwen2.5_7b_pd_disaggregation.yaml | 2 +- flagscale/runner/runner_serve.py | 20 ++- flagscale/serve/run_pd_disagg_router.py | 145 +++++++++++------- 3 files changed, 108 insertions(+), 59 deletions(-) diff --git a/examples/qwen/conf/config_qwen2.5_7b_pd_disaggregation.yaml b/examples/qwen/conf/config_qwen2.5_7b_pd_disaggregation.yaml index 51d774ae7..25b8144b4 100644 --- a/examples/qwen/conf/config_qwen2.5_7b_pd_disaggregation.yaml +++ b/examples/qwen/conf/config_qwen2.5_7b_pd_disaggregation.yaml @@ -13,7 +13,7 @@ experiment: prefill_decode_disaggregation: true prefill_num: 1 prefill_address: 10.1.1.122 - decode_num: 1 + decode_num: 2 decode_address: 10.1.1.108 runner: hostfile: examples/deepseek/conf/hostfile.txt diff --git a/flagscale/runner/runner_serve.py b/flagscale/runner/runner_serve.py index 53247e99f..c7e3ef002 100644 --- a/flagscale/runner/runner_serve.py +++ b/flagscale/runner/runner_serve.py @@ -206,6 +206,14 @@ def _reset_serve_port(config): return model_port +def _attach_kv_proxy_port(config, proxy_port): + deploy_config = config.experiment.get("deploy", {}) + OmegaConf.set_struct(config, False) + deploy_config["proxy_port"] = proxy_port + OmegaConf.set_struct(config, True) + return + + def _get_inference_engine(config): serve_config = config.get("serve", []) if not serve_config: @@ -312,8 +320,12 @@ def _generate_run_script_serve(config, host, node_rank, cmd, background=True, wi d_num = deploy_config.get("decode_num", 1) ports_num = (p_num + d_num) * 2 + 1 kv_related_ports = _get_multiple_free_ports(ports_num) - kv_proxy_port = kv_related_ports.pop() - kv_proxy_port = 30001 # debug, tobe removed + pd_proxy_port = kv_related_ports.pop() + # pd_proxy_port = 30001 # debug, tobe removed + _attach_kv_proxy_port(config, pd_proxy_port) + print( + f"------------ update with port {pd_proxy_port} of new config {config} " + ) engine_args = _get_engine_args(config) command_items = ["vllm", "serve"] @@ -370,7 +382,7 @@ def _generate_run_script_serve(config, host, node_rank, cmd, background=True, wi "kv_port": str(kv_port), "kv_connector_extra_config": { "proxy_ip": master_ip, - "proxy_port": str(kv_proxy_port), + "proxy_port": str(pd_proxy_port), "http_port": str(http_port), }, } @@ -413,7 +425,7 @@ def _generate_run_script_serve(config, host, node_rank, cmd, background=True, wi "kv_port": str(kv_port), "kv_connector_extra_config": { "proxy_ip": master_ip, - "proxy_port": str(kv_proxy_port), + "proxy_port": str(pd_proxy_port), "http_port": str(http_port), }, } diff --git a/flagscale/serve/run_pd_disagg_router.py b/flagscale/serve/run_pd_disagg_router.py index 16826fd94..ebd577e57 100644 --- a/flagscale/serve/run_pd_disagg_router.py +++ b/flagscale/serve/run_pd_disagg_router.py @@ -1,68 +1,83 @@ +import logging import os import random import socket import threading import uuid -import logging import aiohttp import msgpack import zmq from quart import Quart, make_response, request -# ----------------------------------------------------------------------------- -# Logging configuration -# 
----------------------------------------------------------------------------- -logging.basicConfig(level=logging.INFO) -logger = logging.getLogger(__name__) +try: + import flag_scale +except Exception as e: + pass + +from flagscale import serve +from flagscale.logger import logger +from flagscale.utils import flatten_dict_to_args + +# logging.basicConfig(level=logging.INFO) +# logger = logging.getLogger(__name__) + # ----------------------------------------------------------------------------- -# ResourceManager: unified management of P/D instances and their load +# LoadManager: unified management of P/D instances and their load # ----------------------------------------------------------------------------- -class ResourceManager: +class LoadManager: def __init__(self): self._lock = threading.Lock() # Each resource type 'P' or 'D' maps to {http_addr: {'zmq': zmq_addr, 'load': int}} self._instances: dict[str, dict[str, dict[str, object]]] = { - 'P': {}, - 'D': {}, + "P": {}, + "D": {}, } def register(self, rtype: str, http_addr: str, zmq_addr: str): with self._lock: if http_addr not in self._instances[rtype]: - self._instances[rtype][http_addr] = {'zmq': zmq_addr, 'load': 0} - logger.info(f"Registered new {rtype}-instance {http_addr} (zmq={zmq_addr})") + self._instances[rtype][http_addr] = {"zmq": zmq_addr, "load": 0} + logger.info( + f"Registered new {rtype}-instance {http_addr} (zmq={zmq_addr})" + ) else: # If zmq address changed, synchronize it - self._instances[rtype][http_addr]['zmq'] = zmq_addr + self._instances[rtype][http_addr]["zmq"] = zmq_addr def increment_load(self, rtype: str, http_addr: str): with self._lock: - self._instances[rtype][http_addr]['load'] += 1 - logger.debug(f"[{rtype}] +1 load on {http_addr}, now={self._instances[rtype][http_addr]['load']}") + self._instances[rtype][http_addr]["load"] += 1 + logger.debug( + f"[{rtype}] +1 load on {http_addr}, now={self._instances[rtype][http_addr]['load']}" + ) def decrement_load(self, rtype: str, http_addr: str): with self._lock: - self._instances[rtype][http_addr]['load'] -= 1 - logger.debug(f"[{rtype}] -1 load on {http_addr}, now={self._instances[rtype][http_addr]['load']}") + self._instances[rtype][http_addr]["load"] -= 1 + logger.debug( + f"[{rtype}] -1 load on {http_addr}, now={self._instances[rtype][http_addr]['load']}" + ) def get_random(self, rtype: str) -> tuple[str, str]: with self._lock: items = list(self._instances[rtype].items()) http_addr, info = random.choice(items) - return http_addr, info['zmq'] + return http_addr, info["zmq"] def get_least_loaded(self, rtype: str) -> tuple[str, str]: with self._lock: - http_addr, info = min(self._instances[rtype].items(), - key=lambda kv: kv[1]['load']) - return http_addr, info['zmq'] + http_addr, info = min( + self._instances[rtype].items(), key=lambda kv: kv[1]["load"] + ) + return http_addr, info["zmq"] + # ----------------------------------------------------------------------------- # Globals & configuration # ----------------------------------------------------------------------------- -rm = ResourceManager() +lm = LoadManager() # Legacy registration dicts & Conditions retained for external waiting prefill_instances: dict[str, str] = {} @@ -71,7 +86,8 @@ def get_least_loaded(self, rtype: str) -> tuple[str, str]: decode_cv = threading.Condition() # Scheduling strategy: 'random' or 'least' (least load) -SCHEDULING_STRATEGY = os.environ.get('SCHEDULING_STRATEGY', 'random').lower() +SCHEDULING_STRATEGY = os.environ.get("SCHEDULING_STRATEGY", "random").lower() + # 
----------------------------------------------------------------------------- # Service discovery: receive instance registrations @@ -88,14 +104,15 @@ def _listen_for_register(poller, router_socket): if typ == "P": with prefill_cv: prefill_instances[http_addr] = zmq_addr - rm.register('P', http_addr, zmq_addr) + lm.register("P", http_addr, zmq_addr) elif typ == "D": with decode_cv: decode_instances[http_addr] = zmq_addr - rm.register('D', http_addr, zmq_addr) + lm.register("D", http_addr, zmq_addr) else: logger.warning(f"Unexpected registration message: {data}") + def start_service_discovery(hostname, port): if not hostname: hostname = socket.gethostname() @@ -110,27 +127,28 @@ def start_service_discovery(hostname, port): poller.register(router_socket, zmq.POLLIN) listener = threading.Thread( - target=_listen_for_register, - args=[poller, router_socket], - daemon=True + target=_listen_for_register, args=[poller, router_socket], daemon=True ) listener.start() return listener + # ----------------------------------------------------------------------------- # HTTP proxy & request forwarding # ----------------------------------------------------------------------------- AIOHTTP_TIMEOUT = aiohttp.ClientTimeout(total=6 * 60 * 60) app = Quart(__name__) + def random_uuid() -> str: return uuid.uuid4().hex + async def forward_request(url, data, request_id): async with aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session: headers = { "Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}", - "X-Request-Id": request_id + "X-Request-Id": request_id, } async with session.post(url=url, json=data, headers=headers) as resp: if resp.status == 200: @@ -140,30 +158,33 @@ async def forward_request(url, data, request_id): content = await resp.read() yield content + # support both /v1/completions and /v1/chat/completions -@app.route('/v1/completions', methods=['POST']) -@app.route('/v1/chat/completions', methods=['POST']) +@app.route("/v1/completions", methods=["POST"]) +@app.route("/v1/chat/completions", methods=["POST"]) async def handle_request(): try: original_data = await request.get_json() - endpoint = request.path # this will be '/v1/completions' or '/v1/chat/completions' - + endpoint = ( + request.path + ) # this will be '/v1/completions' or '/v1/chat/completions' + # Prefill request: max_tokens=1 prefill_request = original_data.copy() - prefill_request['max_tokens'] = 1 + prefill_request["max_tokens"] = 1 # Select Prefill instance - if SCHEDULING_STRATEGY == 'least': - prefill_addr, prefill_zmq = rm.get_least_loaded('P') + if SCHEDULING_STRATEGY == "least": + prefill_addr, prefill_zmq = lm.get_least_loaded("P") else: - prefill_addr, prefill_zmq = rm.get_random('P') + prefill_addr, prefill_zmq = lm.get_random("P") logger.info(f"Selected P-instance {prefill_addr} via '{SCHEDULING_STRATEGY}'") # Select Decode instance - if SCHEDULING_STRATEGY == 'least': - decode_addr, decode_zmq = rm.get_least_loaded('D') + if SCHEDULING_STRATEGY == "least": + decode_addr, decode_zmq = lm.get_least_loaded("D") else: - decode_addr, decode_zmq = rm.get_random('D') + decode_addr, decode_zmq = lm.get_random("D") logger.info(f"Selected D-instance {decode_addr} via '{SCHEDULING_STRATEGY}'") # Keep original request_id composition format @@ -172,23 +193,25 @@ async def handle_request(): ) # Execute Prefill and update load - rm.increment_load('P', prefill_addr) + lm.increment_load("P", prefill_addr) try: - async for _ in forward_request(f'http://{prefill_addr}{endpoint}', - prefill_request, request_id): + async for _ in 
forward_request( + f"http://{prefill_addr}{endpoint}", prefill_request, request_id + ): pass finally: - rm.decrement_load('P', prefill_addr) + lm.decrement_load("P", prefill_addr) # Execute Decode and update load async def tracked_decode(): - rm.increment_load('D', decode_addr) + lm.increment_load("D", decode_addr) try: - async for chunk in forward_request(f'http://{decode_addr}{endpoint}', - original_data, request_id): + async for chunk in forward_request( + f"http://{decode_addr}{endpoint}", original_data, request_id + ): yield chunk finally: - rm.decrement_load('D', decode_addr) + lm.decrement_load("D", decode_addr) resp = await make_response(tracked_decode()) resp.timeout = None @@ -198,10 +221,24 @@ async def tracked_decode(): logger.error("Error in proxy server", exc_info=e) return {"error": str(e)}, 500 -# ----------------------------------------------------------------------------- -# Startup -# ----------------------------------------------------------------------------- -if __name__ == '__main__': - listener = start_service_discovery("0.0.0.0", 30001) - app.run(host='0.0.0.0', port=10001) + +def main(): + serve.load_args() + deploy_config = serve.task_config.experiment.get("deploy", {}) + serve_port = deploy_config.get("port", None) + # Used to register with the pd service discovery + pd_proxy_port = deploy_config.get("pd_proxy_port", None) + if not serve_port: + raise ValueError("No port specified in deploy config") + if not pd_proxy_port: + raise ValueError("No pd_proxy_port specified in deploy config") + print( + f"Starting Proxy Server...with pd_proxy_port {pd_proxy_port} and serve_port {serve_port}" + ) + listener = start_service_discovery("0.0.0.0", pd_proxy_port) + app.run(host="0.0.0.0", port=serve_port) listener.join() + + +if __name__ == "__main__": + main() From 9aaa50bc56175a0e97090a9893c10c4478184c09 Mon Sep 17 00:00:00 2001 From: chenzhuo Date: Fri, 18 Apr 2025 10:22:51 +0800 Subject: [PATCH 35/62] v2: dev load --- flagscale/runner/runner_serve.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/flagscale/runner/runner_serve.py b/flagscale/runner/runner_serve.py index c7e3ef002..04312788f 100644 --- a/flagscale/runner/runner_serve.py +++ b/flagscale/runner/runner_serve.py @@ -206,10 +206,10 @@ def _reset_serve_port(config): return model_port -def _attach_kv_proxy_port(config, proxy_port): +def _attach_kv_proxy_port(config, pd_proxy_port): deploy_config = config.experiment.get("deploy", {}) OmegaConf.set_struct(config, False) - deploy_config["proxy_port"] = proxy_port + deploy_config["pd_proxy_port"] = pd_proxy_port OmegaConf.set_struct(config, True) return From c9acd717c42323559637ef8f6364eea6f6e13ead Mon Sep 17 00:00:00 2001 From: chenzhuo Date: Fri, 18 Apr 2025 10:41:12 +0800 Subject: [PATCH 36/62] v2: dev load --- flagscale/runner/runner_serve.py | 22 +++++++++++----------- 1 file changed, 11 insertions(+), 11 deletions(-) diff --git a/flagscale/runner/runner_serve.py b/flagscale/runner/runner_serve.py index 04312788f..f820e6305 100644 --- a/flagscale/runner/runner_serve.py +++ b/flagscale/runner/runner_serve.py @@ -206,14 +206,6 @@ def _reset_serve_port(config): return model_port -def _attach_kv_proxy_port(config, pd_proxy_port): - deploy_config = config.experiment.get("deploy", {}) - OmegaConf.set_struct(config, False) - deploy_config["pd_proxy_port"] = pd_proxy_port - OmegaConf.set_struct(config, True) - return - - def _get_inference_engine(config): serve_config = config.get("serve", []) if not serve_config: @@ -242,6 +234,8 @@ def 
_get_engine_args(config, model="vllm_model"): def _update_config_serve(config: DictConfig): + deploy_config = config.experiment.get("deploy", {}) + exp_dir = os.path.abspath(config.experiment.exp_dir) if not os.path.isdir(exp_dir): os.makedirs(exp_dir) @@ -249,6 +243,9 @@ def _update_config_serve(config: DictConfig): OmegaConf.set_struct(config, False) + if deploy_config.get("prefill_decode_disaggregation", False): + deploy_config["pd_proxy_port"] = get_free_port() + if config.get("logging", None) is None: config.logging = DictConfig({}) @@ -318,11 +315,14 @@ def _generate_run_script_serve(config, host, node_rank, cmd, background=True, wi target_port = nodes[0][1].get("port") p_num = deploy_config.get("prefill_num", 1) d_num = deploy_config.get("decode_num", 1) - ports_num = (p_num + d_num) * 2 + 1 + ports_num = (p_num + d_num) * 2 kv_related_ports = _get_multiple_free_ports(ports_num) - pd_proxy_port = kv_related_ports.pop() + pd_proxy_port = deploy_config.get("pd_proxy_port", None) + if not pd_proxy_port: + raise ValueError( + f"PD disaggregation requires a proxy port to be set." + ) # pd_proxy_port = 30001 # debug, tobe removed - _attach_kv_proxy_port(config, pd_proxy_port) print( f"------------ update with port {pd_proxy_port} of new config {config} " ) From 4e603870ea33798a8d4ef2f458a7210e599c94e1 Mon Sep 17 00:00:00 2001 From: chenzhuo Date: Fri, 18 Apr 2025 10:46:08 +0800 Subject: [PATCH 37/62] polish code --- examples/deepseek/conf/hostfile.txt | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/examples/deepseek/conf/hostfile.txt b/examples/deepseek/conf/hostfile.txt index e49bb3049..0d8b1e05f 100644 --- a/examples/deepseek/conf/hostfile.txt +++ b/examples/deepseek/conf/hostfile.txt @@ -1,5 +1,5 @@ # ip slots type=xxx[optional] # master node -10.1.1.122 slots=8 type=gpu +x.x.x.x slots=8 type=gpu # worker nodes -10.1.1.108 slots=8 type=gpu +x.x.x.x slots=8 type=gpu From 3f6f18ec078a3ef2d94034c4d38591e2355ba035 Mon Sep 17 00:00:00 2001 From: chenzhuo Date: Fri, 18 Apr 2025 10:54:06 +0800 Subject: [PATCH 38/62] polish code --- examples/qwen/conf/hostfile.txt | 5 +++++ flagscale/serve/run_pd_disagg_router.py | 14 +++++++------- 2 files changed, 12 insertions(+), 7 deletions(-) create mode 100644 examples/qwen/conf/hostfile.txt diff --git a/examples/qwen/conf/hostfile.txt b/examples/qwen/conf/hostfile.txt new file mode 100644 index 000000000..0d8b1e05f --- /dev/null +++ b/examples/qwen/conf/hostfile.txt @@ -0,0 +1,5 @@ +# ip slots type=xxx[optional] +# master node +x.x.x.x slots=8 type=gpu +# worker nodes +x.x.x.x slots=8 type=gpu diff --git a/flagscale/serve/run_pd_disagg_router.py b/flagscale/serve/run_pd_disagg_router.py index ebd577e57..52df74e8a 100644 --- a/flagscale/serve/run_pd_disagg_router.py +++ b/flagscale/serve/run_pd_disagg_router.py @@ -66,7 +66,7 @@ def get_random(self, rtype: str) -> tuple[str, str]: http_addr, info = random.choice(items) return http_addr, info["zmq"] - def get_least_loaded(self, rtype: str) -> tuple[str, str]: + def get_robin_loaded(self, rtype: str) -> tuple[str, str]: with self._lock: http_addr, info = min( self._instances[rtype].items(), key=lambda kv: kv[1]["load"] @@ -85,8 +85,8 @@ def get_least_loaded(self, rtype: str) -> tuple[str, str]: prefill_cv = threading.Condition() decode_cv = threading.Condition() -# Scheduling strategy: 'random' or 'least' (least load) -SCHEDULING_STRATEGY = os.environ.get("SCHEDULING_STRATEGY", "random").lower() +# Scheduling strategy: 'random' or 'robin' (robin load) +SCHEDULING_STRATEGY = 
os.environ.get("SCHEDULING_STRATEGY", "robin").lower() # ----------------------------------------------------------------------------- @@ -174,15 +174,15 @@ async def handle_request(): prefill_request["max_tokens"] = 1 # Select Prefill instance - if SCHEDULING_STRATEGY == "least": - prefill_addr, prefill_zmq = lm.get_least_loaded("P") + if SCHEDULING_STRATEGY == "robin": + prefill_addr, prefill_zmq = lm.get_robin_loaded("P") else: prefill_addr, prefill_zmq = lm.get_random("P") logger.info(f"Selected P-instance {prefill_addr} via '{SCHEDULING_STRATEGY}'") # Select Decode instance - if SCHEDULING_STRATEGY == "least": - decode_addr, decode_zmq = lm.get_least_loaded("D") + if SCHEDULING_STRATEGY == "robin": + decode_addr, decode_zmq = lm.get_robin_loaded("D") else: decode_addr, decode_zmq = lm.get_random("D") logger.info(f"Selected D-instance {decode_addr} via '{SCHEDULING_STRATEGY}'") From bcd23d7821d6a11faf38aacdcaf3c16a0bc80d2c Mon Sep 17 00:00:00 2001 From: chenzhuo Date: Fri, 18 Apr 2025 10:55:47 +0800 Subject: [PATCH 39/62] polish code --- examples/qwen/conf/config_qwen2.5_7b_pd_disaggregation.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/qwen/conf/config_qwen2.5_7b_pd_disaggregation.yaml b/examples/qwen/conf/config_qwen2.5_7b_pd_disaggregation.yaml index 25b8144b4..857bd24ca 100644 --- a/examples/qwen/conf/config_qwen2.5_7b_pd_disaggregation.yaml +++ b/examples/qwen/conf/config_qwen2.5_7b_pd_disaggregation.yaml @@ -16,7 +16,7 @@ experiment: decode_num: 2 decode_address: 10.1.1.108 runner: - hostfile: examples/deepseek/conf/hostfile.txt + hostfile: examples/qwen/conf/hostfile.txt docker: fr-v2 envs: CUDA_DEVICE_MAX_CONNECTIONS: 1 From 9dfd44a2d1ebc958c960337b033379d005c343ff Mon Sep 17 00:00:00 2001 From: chenzhuo Date: Fri, 18 Apr 2025 10:57:59 +0800 Subject: [PATCH 40/62] polish code --- examples/qwen/conf/config_qwen2.5_7b_pd_disaggregation.yaml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/examples/qwen/conf/config_qwen2.5_7b_pd_disaggregation.yaml b/examples/qwen/conf/config_qwen2.5_7b_pd_disaggregation.yaml index 857bd24ca..9af210a1d 100644 --- a/examples/qwen/conf/config_qwen2.5_7b_pd_disaggregation.yaml +++ b/examples/qwen/conf/config_qwen2.5_7b_pd_disaggregation.yaml @@ -12,9 +12,9 @@ experiment: use_fs_serve: false prefill_decode_disaggregation: true prefill_num: 1 - prefill_address: 10.1.1.122 + prefill_address: x.x.x.x # optional, default "auto" decode_num: 2 - decode_address: 10.1.1.108 + decode_address: x.x.x.x # optional, default "auto" runner: hostfile: examples/qwen/conf/hostfile.txt docker: fr-v2 From 514ceb998a35363a17784d864573d8b880ba5d93 Mon Sep 17 00:00:00 2001 From: chenzhuo Date: Fri, 18 Apr 2025 11:06:11 +0800 Subject: [PATCH 41/62] polish code --- flagscale/serve/run_pd_disagg_router.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/flagscale/serve/run_pd_disagg_router.py b/flagscale/serve/run_pd_disagg_router.py index 52df74e8a..8d75715d9 100644 --- a/flagscale/serve/run_pd_disagg_router.py +++ b/flagscale/serve/run_pd_disagg_router.py @@ -71,6 +71,10 @@ def get_robin_loaded(self, rtype: str) -> tuple[str, str]: http_addr, info = min( self._instances[rtype].items(), key=lambda kv: kv[1]["load"] ) + print( + f"========== whole instance status {self._instances}==========", + flush=True, + ) return http_addr, info["zmq"] From d44b1fcf94ab9fd6a3307a5760d53454a13a8352 Mon Sep 17 00:00:00 2001 From: chenzhuo Date: Fri, 18 Apr 2025 11:10:28 +0800 Subject: [PATCH 42/62] polish code --- 
examples/qwen/conf/target.sh | 36 ------------------------------------ 1 file changed, 36 deletions(-) delete mode 100644 examples/qwen/conf/target.sh diff --git a/examples/qwen/conf/target.sh b/examples/qwen/conf/target.sh deleted file mode 100644 index aef7b1dad..000000000 --- a/examples/qwen/conf/target.sh +++ /dev/null @@ -1,36 +0,0 @@ -#!/bin/bash - -set -x - -MODEL_NAME=/models/Qwen2.5-7B-Instruct - -source /root/miniconda3/bin/activate flagscale-inference && export GLOO_SOCKET_IFNAME=bond0 - -if [ -z "$PYTHONPATH" ]; then - export PYTHONPATH=/root/miniconda3/envs/flagscale-inference/lib/python3.12/site-packages:/mine/ip122/tune_qwen/github_flagscale -else - export PYTHONPATH="$PYTHONPATH:/root/miniconda3/envs/flagscale-inference/lib/python3.12/site-packages:/mine/ip122/tune_qwen/github_flagscale" -fi - -ray_path=$(realpath $(which ray)) -# clean nodes -ssh -n -p 22 10.1.1.108 "docker exec ds /bin/bash -c 'source /root/miniconda3/bin/activate flagscale-inference && export GLOO_SOCKET_IFNAME=bond0 && ${ray_path} stop'" -source /root/miniconda3/bin/activate flagscale-inference && export GLOO_SOCKET_IFNAME=bond0 && ${ray_path} stop -pkill -f 'run_inference_engine' -pkill -f 'run_fs_serve_vllm' -pkill -f 'vllm serve' - -# start cluster -# master node -source /root/miniconda3/bin/activate flagscale-inference && export GLOO_SOCKET_IFNAME=bond0 && ${ray_path} start --head --port=59081 --num-gpus=8 - -# worker nodes -ssh -n -p 22 10.1.1.108 "docker exec ds /bin/bash -c 'source /root/miniconda3/bin/activate flagscale-inference && export CUDA_VISIBLE_DEVICES=7 && ${ray_path} start --address=10.1.1.122:59081 --num-gpus=8'" -mkdir -p /mine/ip122/tune_qwen/github_flagscale/outputs/deepseek_v3/serve_logs -mkdir -p /mine/ip122/tune_qwen/github_flagscale/outputs/deepseek_v3/serve_logs/pids - -cd /mine/ip122/tune_qwen/github_flagscale - -cmd="CUDA_DEVICE_MAX_CONNECTIONS=1 python flagscale/serve/run_inference_engine.py --config-path=/mine/ip122/tune_qwen/github_flagscale/outputs/deepseek_v3/serve_logs/scripts/serve.yaml --log-dir=/mine/ip122/tune_qwen/github_flagscale/outputs/deepseek_v3/serve_logs" - -nohup bash -c "$cmd; sync" >> /mine/ip122/tune_qwen/github_flagscale/outputs/deepseek_v3/serve_logs/host_0_localhost.output 2>&1 & echo $! 
> /mine/ip122/tune_qwen/github_flagscale/outputs/deepseek_v3/serve_logs/pids/host_0_localhost.pid \ No newline at end of file From a4fe733e436f4e165fcda59754a6de7fdc0e404b Mon Sep 17 00:00:00 2001 From: chenzhuo Date: Fri, 18 Apr 2025 14:12:24 +0800 Subject: [PATCH 43/62] polish code --- flagscale/serve/run_pd_disagg_router.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/flagscale/serve/run_pd_disagg_router.py b/flagscale/serve/run_pd_disagg_router.py index 8d75715d9..566b32fd0 100644 --- a/flagscale/serve/run_pd_disagg_router.py +++ b/flagscale/serve/run_pd_disagg_router.py @@ -19,8 +19,7 @@ from flagscale.logger import logger from flagscale.utils import flatten_dict_to_args -# logging.basicConfig(level=logging.INFO) -# logger = logging.getLogger(__name__) +# Reference https://github.com/vllm-project/vllm/pull/15806 # ----------------------------------------------------------------------------- From 6337fe11b743c4c9dbf714c66fe0f2b08bc71969 Mon Sep 17 00:00:00 2001 From: chenzhuo Date: Fri, 18 Apr 2025 17:22:29 +0800 Subject: [PATCH 44/62] remove debug code --- examples/qwen/conf/refer.sh | 34 ---------------------------------- 1 file changed, 34 deletions(-) delete mode 100644 examples/qwen/conf/refer.sh diff --git a/examples/qwen/conf/refer.sh b/examples/qwen/conf/refer.sh deleted file mode 100644 index f0668435a..000000000 --- a/examples/qwen/conf/refer.sh +++ /dev/null @@ -1,34 +0,0 @@ -#!/bin/bash - -set -x - -source /root/miniconda3/bin/activate flagscale-inference && export GLOO_SOCKET_IFNAME=bond0 - -if [ -z "$PYTHONPATH" ]; then - export PYTHONPATH=/root/miniconda3/envs/flagscale-inference/lib/python3.12/site-packages:/mine/ip122/tune_qwen/github_flagscale -else - export PYTHONPATH="$PYTHONPATH:/root/miniconda3/envs/flagscale-inference/lib/python3.12/site-packages:/mine/ip122/tune_qwen/github_flagscale" -fi - -ray_path=$(realpath $(which ray)) -# clean nodes -ssh -n -p 22 10.1.1.108 "docker exec ds /bin/bash -c 'source /root/miniconda3/bin/activate flagscale-inference && export GLOO_SOCKET_IFNAME=bond0 && ${ray_path} stop'" -source /root/miniconda3/bin/activate flagscale-inference && export GLOO_SOCKET_IFNAME=bond0 && ${ray_path} stop -pkill -f 'run_inference_engine' -pkill -f 'run_fs_serve_vllm' -pkill -f 'vllm serve' - -# start cluster -# master node -source /root/miniconda3/bin/activate flagscale-inference && export GLOO_SOCKET_IFNAME=bond0 && ${ray_path} start --head --port=59081 --num-gpus=8 - -# worker nodes -ssh -n -p 22 10.1.1.108 "docker exec ds /bin/bash -c 'source /root/miniconda3/bin/activate flagscale-inference && export GLOO_SOCKET_IFNAME=bond0 && ${ray_path} start --address=10.1.1.122:59081 --num-gpus=8'" -mkdir -p /mine/ip122/tune_qwen/github_flagscale/outputs/deepseek_v3/serve_logs -mkdir -p /mine/ip122/tune_qwen/github_flagscale/outputs/deepseek_v3/serve_logs/pids - -cd /mine/ip122/tune_qwen/github_flagscale - -cmd="CUDA_DEVICE_MAX_CONNECTIONS=1 python flagscale/serve/run_inference_engine.py --config-path=/mine/ip122/tune_qwen/github_flagscale/outputs/deepseek_v3/serve_logs/scripts/serve.yaml --log-dir=/mine/ip122/tune_qwen/github_flagscale/outputs/deepseek_v3/serve_logs" - -nohup bash -c "$cmd; sync" >> /mine/ip122/tune_qwen/github_flagscale/outputs/deepseek_v3/serve_logs/host_0_localhost.output 2>&1 & echo $! 
> /mine/ip122/tune_qwen/github_flagscale/outputs/deepseek_v3/serve_logs/pids/host_0_localhost.pid \ No newline at end of file From 42543a8287142d00d27a0b4f0d2c9af83b96aa87 Mon Sep 17 00:00:00 2001 From: chenzhuo Date: Tue, 29 Apr 2025 00:07:42 +0800 Subject: [PATCH 45/62] code --- flagscale/runner/runner_serve.py | 26 ++++++++++++++++--------- flagscale/serve/run_pd_disagg_router.py | 2 +- 2 files changed, 18 insertions(+), 10 deletions(-) diff --git a/flagscale/runner/runner_serve.py b/flagscale/runner/runner_serve.py index f820e6305..925ca96a7 100644 --- a/flagscale/runner/runner_serve.py +++ b/flagscale/runner/runner_serve.py @@ -293,6 +293,7 @@ def _generate_run_script_serve(config, host, node_rank, cmd, background=True, wi except Exception as e: vllm_path = f"{root_dir}/vllm" deploy_config = config.experiment.get("deploy", {}) + envs = config.experiment.get("envs", {}) print(f"shell file ======================== {host_run_script_file}", flush=True) with open(host_run_script_file, "w") as f: f.write("#!/bin/bash\n\n") @@ -307,6 +308,7 @@ def _generate_run_script_serve(config, host, node_rank, cmd, background=True, wi f.write(f' export PYTHONPATH="$PYTHONPATH:{vllm_path}:{root_dir}"\n') f.write(f"fi\n") f.write(f"\n") + envs_str = " && ".join(f"export {key}={value}" for key, value in envs.items()) if nodes: if deploy_config.get("prefill_decode_disaggregation", False): @@ -322,10 +324,6 @@ def _generate_run_script_serve(config, host, node_rank, cmd, background=True, wi raise ValueError( f"PD disaggregation requires a proxy port to be set." ) - # pd_proxy_port = 30001 # debug, tobe removed - print( - f"------------ update with port {pd_proxy_port} of new config {config} " - ) engine_args = _get_engine_args(config) command_items = ["vllm", "serve"] @@ -333,9 +331,11 @@ def _generate_run_script_serve(config, host, node_rank, cmd, background=True, wi other_args = flatten_dict_to_args(engine_args, ["model", "port"]) command_items.extend(other_args) vllm_command = " ".join(command_items) - vllm_command = "nohup " + vllm_command + # vllm_command = "nohup " + vllm_command if before_start_cmd: vllm_command = f"{before_start_cmd} && " + vllm_command + if envs_str: + vllm_command = f"{envs_str} && " + vllm_command p_address = deploy_config.get("prefill_address", "auto") d_address = deploy_config.get("decode_address", "auto") tensor_parallel_size = engine_args.get("tensor_parallel_size", 1) @@ -411,8 +411,12 @@ def _generate_run_script_serve(config, host, node_rank, cmd, background=True, wi ssh_cmd = f'ssh -f -n -p {ssh_port} {d_address} "{node_cmd} > {p_instance_log_path} 2>&1 &"' f.write(f"{ssh_cmd}\n\n") else: - node_cmd = f"{ids_env} && {vllm_command} --port {http_port} --kv-transfer-config '{p_kv_config_json}' > {p_instance_log_path} 2>&1 &" - f.write(f"{node_cmd}\n\n") + p_cmd = f"{ids_env} && {vllm_command} --port {http_port} --kv-transfer-config '\\''{p_kv_config_json}'\\''" + f.write(f"p_{j}_cmd='{p_cmd}'\n") + f.write(f"\n") + f.write( + f'nohup bash -c "$p_{i}_cmd; sync" >> {p_instance_log_path} 2>&1 &\n\n' + ) f.write("echo '=========== launch decode instance ==========='\n") @@ -454,8 +458,12 @@ def _generate_run_script_serve(config, host, node_rank, cmd, background=True, wi ssh_cmd = f'ssh -f -n -p {ssh_port} {d_address} "{node_cmd} > {d_instance_log_path} 2>&1 &"' f.write(f"{ssh_cmd}\n\n") else: - node_cmd = f"{ids_env} && {vllm_command} --port {http_port} --kv-transfer-config '{d_kv_config_json}' > {d_instance_log_path} 2>&1 &" - f.write(f"{node_cmd}\n\n") + d_cmd = f"{ids_env} && 
{vllm_command} --port {http_port} --kv-transfer-config '\\''{d_kv_config_json}'\\''" + f.write(f"d_{j}_cmd='{d_cmd}'\n") + f.write(f"\n") + f.write( + f'nohup bash -c "$d_{i}_cmd; sync" >> {d_instance_log_path} 2>&1 &\n\n' + ) else: f.write(f"ray_path=$(realpath $(which ray))\n") diff --git a/flagscale/serve/run_pd_disagg_router.py b/flagscale/serve/run_pd_disagg_router.py index 566b32fd0..8e45a443b 100644 --- a/flagscale/serve/run_pd_disagg_router.py +++ b/flagscale/serve/run_pd_disagg_router.py @@ -19,7 +19,7 @@ from flagscale.logger import logger from flagscale.utils import flatten_dict_to_args -# Reference https://github.com/vllm-project/vllm/pull/15806 +# Refer to https://github.com/vllm-project/vllm/pull/15806 # ----------------------------------------------------------------------------- From 4d48282091b247fa813d5dc799da14b823bba5c3 Mon Sep 17 00:00:00 2001 From: chenzhuo Date: Tue, 29 Apr 2025 00:15:27 +0800 Subject: [PATCH 46/62] fix code --- examples/qwen/conf/config_qwen2.5_7b_pd_disaggregation.yaml | 6 +++--- examples/qwen/conf/hostfile.txt | 4 ++-- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/examples/qwen/conf/config_qwen2.5_7b_pd_disaggregation.yaml b/examples/qwen/conf/config_qwen2.5_7b_pd_disaggregation.yaml index 9af210a1d..2f0e42fdf 100644 --- a/examples/qwen/conf/config_qwen2.5_7b_pd_disaggregation.yaml +++ b/examples/qwen/conf/config_qwen2.5_7b_pd_disaggregation.yaml @@ -12,16 +12,16 @@ experiment: use_fs_serve: false prefill_decode_disaggregation: true prefill_num: 1 - prefill_address: x.x.x.x # optional, default "auto" + prefill_address: 10.1.1.122 # optional, default "auto" decode_num: 2 - decode_address: x.x.x.x # optional, default "auto" + decode_address: 10.1.1.108 # optional, default "auto" runner: hostfile: examples/qwen/conf/hostfile.txt docker: fr-v2 envs: CUDA_DEVICE_MAX_CONNECTIONS: 1 cmds: - before_start: source /root/miniconda3/bin/activate flagscale-inference + before_start: export VLLM_USE_V1=0 && export FLAGCX_SOCKET_IFNAME=bond0 && export VLLM_USE_V1=0 && source /root/miniconda3/bin/activate flagscale-inference && export FLAGCX_PATH=/mine/ip122/tune_qwen/FlagCX/ && export FLAGCX_DEBUG=TRACE && export FLAGCX_DEBUG_SUBSYS=ALL && export NCCL_DEBUG=TRACE && export NCCL_DEBUG_SUBSYS=ALL action: run diff --git a/examples/qwen/conf/hostfile.txt b/examples/qwen/conf/hostfile.txt index 0d8b1e05f..e49bb3049 100644 --- a/examples/qwen/conf/hostfile.txt +++ b/examples/qwen/conf/hostfile.txt @@ -1,5 +1,5 @@ # ip slots type=xxx[optional] # master node -x.x.x.x slots=8 type=gpu +10.1.1.122 slots=8 type=gpu # worker nodes -x.x.x.x slots=8 type=gpu +10.1.1.108 slots=8 type=gpu From 2a0d12a2e404050f10e2959a190fefd21e771e5a Mon Sep 17 00:00:00 2001 From: chenzhuo Date: Tue, 29 Apr 2025 00:20:53 +0800 Subject: [PATCH 47/62] fix code --- flagscale/runner/runner_serve.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/flagscale/runner/runner_serve.py b/flagscale/runner/runner_serve.py index 925ca96a7..cf2eba568 100644 --- a/flagscale/runner/runner_serve.py +++ b/flagscale/runner/runner_serve.py @@ -412,7 +412,7 @@ def _generate_run_script_serve(config, host, node_rank, cmd, background=True, wi f.write(f"{ssh_cmd}\n\n") else: p_cmd = f"{ids_env} && {vllm_command} --port {http_port} --kv-transfer-config '\\''{p_kv_config_json}'\\''" - f.write(f"p_{j}_cmd='{p_cmd}'\n") + f.write(f"p_{i}_cmd='{p_cmd}'\n") f.write(f"\n") f.write( f'nohup bash -c "$p_{i}_cmd; sync" >> {p_instance_log_path} 2>&1 &\n\n' From 
e7dca535e21c21755d75f92deb5b34e88b4a4749 Mon Sep 17 00:00:00 2001 From: chenzhuo Date: Tue, 29 Apr 2025 09:55:48 +0800 Subject: [PATCH 48/62] add pd of vllm support flagcx --- .../device_communicators/flagcx_wrapper.py | 405 +++++++++++++++ .../device_communicators/pynccl_wrapper.py | 361 ++++++++++++++ .../kv_transfer/kv_connector/factory.py | 64 +++ .../kv_transfer/kv_connector/p2p_connector.py | 300 +++++++++++ .../kv_pipe/flagcx_p2p_nccl_pipe.py | 434 ++++++++++++++++ .../kv_transfer/kv_pipe/p2p_nccl_pipe.py | 472 ++++++++++++++++++ third_party/vllm | 2 +- 7 files changed, 2037 insertions(+), 1 deletion(-) create mode 100644 flagscale/backends/vllm/vllm/distributed/device_communicators/flagcx_wrapper.py create mode 100644 flagscale/backends/vllm/vllm/distributed/device_communicators/pynccl_wrapper.py create mode 100644 flagscale/backends/vllm/vllm/distributed/kv_transfer/kv_connector/factory.py create mode 100644 flagscale/backends/vllm/vllm/distributed/kv_transfer/kv_connector/p2p_connector.py create mode 100644 flagscale/backends/vllm/vllm/distributed/kv_transfer/kv_pipe/flagcx_p2p_nccl_pipe.py create mode 100644 flagscale/backends/vllm/vllm/distributed/kv_transfer/kv_pipe/p2p_nccl_pipe.py diff --git a/flagscale/backends/vllm/vllm/distributed/device_communicators/flagcx_wrapper.py b/flagscale/backends/vllm/vllm/distributed/device_communicators/flagcx_wrapper.py new file mode 100644 index 000000000..175ebe805 --- /dev/null +++ b/flagscale/backends/vllm/vllm/distributed/device_communicators/flagcx_wrapper.py @@ -0,0 +1,405 @@ +# SPDX-License-Identifier: Apache-2.0 +# reference https://github.com/vllm-project/vllm/blob/main/vllm/distributed/device_communicators/pynccl_wrapper.py + +import ctypes +import platform +from dataclasses import dataclass +from typing import Any, Dict, List, Optional + +import torch +from torch.distributed import ReduceOp + +# === export types and functions from flagcx to Python === +# for the original flagcx definition, please check +# https://github.com/FlagOpen/FlagCX/blob/main/flagcx/include/flagcx.h + +flagcxResult_t = ctypes.c_int +flagcxDataType_t = ctypes.c_int +flagcxRedOp_t = ctypes.c_int +flagcxMemcpyType_t = ctypes.c_int +flagcxMemType_t = ctypes.c_int + +flagcxHandlerGroup_t = ctypes.c_void_p +flagcxComm_t = ctypes.c_void_p +flagcxEvent_t = ctypes.c_void_p +cudaStream_t = ctypes.c_void_p +buffer_type = ctypes.c_void_p + + +class flagcxStream(ctypes.Structure): + _fields_ = [("base", cudaStream_t)] +flagcxStream_t = ctypes.POINTER(flagcxStream) + + +class flagcxUniqueId(ctypes.Structure): + _fields_ = [("internal", ctypes.c_byte * 256)] +flagcxUniqueId_t = ctypes.POINTER(flagcxUniqueId) + + +DEVICE_SYNCHRONIZE_FUNCTYPE = ctypes.CFUNCTYPE(flagcxResult_t) +DEVICE_MEMCPY_FUNCTYPE = ctypes.CFUNCTYPE( + flagcxResult_t, ctypes.c_void_p, ctypes.c_void_p, ctypes.c_size_t, + flagcxMemcpyType_t, flagcxStream_t +) +DEVICE_MEMSET_FUNCTYPE = ctypes.CFUNCTYPE( + flagcxResult_t, ctypes.c_void_p, ctypes.c_int, ctypes.c_size_t, + flagcxMemType_t, flagcxStream_t +) +DEVICE_MALLOC_FUNCTYPE = ctypes.CFUNCTYPE( + flagcxResult_t, ctypes.POINTER(ctypes.c_void_p), ctypes.c_size_t, + flagcxMemType_t, flagcxStream_t +) +DEVICE_FREE_FUNCTYPE = ctypes.CFUNCTYPE( + flagcxResult_t, ctypes.c_void_p, flagcxMemType_t, flagcxStream_t +) +SET_DEVICE_FUNCTYPE = ctypes.CFUNCTYPE(flagcxResult_t, ctypes.c_int) +GET_DEVICE_FUNCTYPE = ctypes.CFUNCTYPE(flagcxResult_t, ctypes.POINTER(ctypes.c_int)) +GET_DEVICE_COUNT_FUNCTYPE = ctypes.CFUNCTYPE(flagcxResult_t, ctypes.POINTER(ctypes.c_int)) 
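+# Each ctypes.CFUNCTYPE in this block takes the return type first
+# (flagcxResult_t), followed by the argument types of the corresponding
+# flagcx device-handle callback.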
+GET_VENDOR_FUNCTYPE = ctypes.CFUNCTYPE(flagcxResult_t, ctypes.c_char_p) + +STREAM_CREATE_FUNCTYPE = ctypes.CFUNCTYPE(flagcxResult_t, ctypes.POINTER(flagcxStream_t)) +STREAM_DESTROY_FUNCTYPE = ctypes.CFUNCTYPE(flagcxResult_t, flagcxStream_t) +STREAM_COPY_FUNCTYPE = ctypes.CFUNCTYPE(flagcxResult_t, ctypes.POINTER(flagcxStream_t), ctypes.c_void_p) +STREAM_FREE_FUNCTYPE = ctypes.CFUNCTYPE(flagcxResult_t, flagcxStream_t) +STREAM_SYNCHRONIZE_FUNCTYPE = ctypes.CFUNCTYPE(flagcxResult_t, flagcxStream_t) +STREAM_QUERY_FUNCTYPE = ctypes.CFUNCTYPE(flagcxResult_t, flagcxStream_t) +STREAM_WAIT_EVENT_FUNCTYPE = ctypes.CFUNCTYPE(flagcxResult_t, flagcxStream_t, flagcxEvent_t) + +EVENT_CREATE_FUNCTYPE = ctypes.CFUNCTYPE(flagcxResult_t, ctypes.POINTER(flagcxEvent_t)) +EVENT_DESTROY_FUNCTYPE = ctypes.CFUNCTYPE(flagcxResult_t, flagcxEvent_t) +EVENT_RECORD_FUNCTYPE = ctypes.CFUNCTYPE(flagcxResult_t, flagcxEvent_t, flagcxStream_t) +EVENT_SYNCHRONIZE_FUNCTYPE = ctypes.CFUNCTYPE(flagcxResult_t, flagcxEvent_t) +EVENT_QUERY_FUNCTYPE = ctypes.CFUNCTYPE(flagcxResult_t, flagcxEvent_t) +class flagcxDeviceHandle(ctypes.Structure): + _fields_ = [ + # Basic functions + ("deviceSynchronize", DEVICE_SYNCHRONIZE_FUNCTYPE), + ("deviceMemcpy", DEVICE_MEMCPY_FUNCTYPE), + ("deviceMemset", DEVICE_MEMSET_FUNCTYPE), + ("deviceMalloc", DEVICE_MALLOC_FUNCTYPE), + ("deviceFree", DEVICE_FREE_FUNCTYPE), + ("setDevice", SET_DEVICE_FUNCTYPE), + ("getDevice", GET_DEVICE_FUNCTYPE), + ("getDeviceCount", GET_DEVICE_COUNT_FUNCTYPE), + ("getVendor", GET_VENDOR_FUNCTYPE), + # Stream functions + ("streamCreate", STREAM_CREATE_FUNCTYPE), + ("streamDestroy", STREAM_DESTROY_FUNCTYPE), + ("streamCopy", STREAM_COPY_FUNCTYPE), + ("streamFree", STREAM_FREE_FUNCTYPE), + ("streamSynchronize", STREAM_SYNCHRONIZE_FUNCTYPE), + ("streamQuery", STREAM_QUERY_FUNCTYPE), + ("streamWaitEvent", STREAM_WAIT_EVENT_FUNCTYPE), + # Event functions + ("eventCreate", EVENT_CREATE_FUNCTYPE), + ("eventDestroy", EVENT_DESTROY_FUNCTYPE), + ("eventRecord", EVENT_RECORD_FUNCTYPE), + ("eventSynchronize", EVENT_SYNCHRONIZE_FUNCTYPE), + ("eventQuery", EVENT_QUERY_FUNCTYPE), + ] +flagcxDeviceHandle_t = ctypes.POINTER(flagcxDeviceHandle) + +class flagcxHandlerGroup(ctypes.Structure): + _fields_ = [ + ("uniqueId", flagcxUniqueId_t), + ("comm", flagcxComm_t), + ("devHandle", flagcxDeviceHandle_t), + ] +flagcxHandlerGroup_t = ctypes.POINTER(flagcxHandlerGroup) + + +class flagcxDataTypeEnum: + flagcxInt8 = 0 + flagcxChar = 0 + flagcxUint8 = 1 + flagcxInt32 = 2 + flagcxInt = 2 + flagcxUint32 = 3 + flagcxInt64 = 4 + flagcxUint64 = 5 + flagcxFloat16 = 6 + flagcxHalf = 6 + flagcxFloat32 = 7 + flagcxFloat = 7 + flagcxFloat64 = 8 + flagcxDouble = 8 + flagcxBfloat16 = 9 + flagcxNumTypes = 10 + + @classmethod + def from_torch(cls, dtype: torch.dtype) -> int: + if dtype == torch.int8: + return cls.flagcxInt8 + if dtype == torch.uint8: + return cls.flagcxUint8 + if dtype == torch.int32: + return cls.flagcxInt32 + if dtype == torch.int64: + return cls.flagcxInt64 + if dtype == torch.float16: + return cls.flagcxFloat16 + if dtype == torch.float32: + return cls.flagcxFloat32 + if dtype == torch.float64: + return cls.flagcxFloat64 + if dtype == torch.bfloat16: + return cls.flagcxBfloat16 + raise ValueError(f"Unsupported dtype: {dtype}") + + +class flagcxRedOpTypeEnum: + flagcxSum = 0 + flagcxProd = 1 + flagcxMax = 2 + flagcxMin = 3 + flagcxAvg = 4 + flagcxNumOps = 5 + + @classmethod + def from_torch(cls, op: ReduceOp) -> int: + if op == ReduceOp.SUM: + return cls.flagcxSum + if op == 
ReduceOp.PRODUCT: + return cls.flagcxProd + if op == ReduceOp.MAX: + return cls.flagcxMax + if op == ReduceOp.MIN: + return cls.flagcxMin + if op == ReduceOp.AVG: + return cls.flagcxAvg + raise ValueError(f"Unsupported op: {op}") + + +@dataclass +class Function: + name: str + restype: Any + argtypes: List[Any] + + +class FLAGCXLibrary: + exported_functions = [ + Function("flagcxHandleInit", flagcxResult_t, + [ctypes.POINTER(flagcxHandlerGroup_t)]), + Function("flagcxHandleFree", flagcxResult_t, + [flagcxHandlerGroup_t]), + Function("flagcxGetErrorString", ctypes.c_char_p, [flagcxResult_t]), + Function("flagcxGetVersion", flagcxResult_t, + [ctypes.POINTER(ctypes.c_int)]), + Function("flagcxGetUniqueId", flagcxResult_t, + [ctypes.POINTER(ctypes.POINTER(flagcxUniqueId))]), + # Note that flagcxComm_t is a pointer type, so the first argument + # is a pointer to a pointer + Function("flagcxCommInitRank", flagcxResult_t, [ + ctypes.POINTER(flagcxComm_t), ctypes.c_int, ctypes.POINTER(flagcxUniqueId), + ctypes.c_int + ]), + # Note that flagcxStream_t is a pointer type, so the last argument + # is a pointer + Function("flagcxAllReduce", flagcxResult_t, [ + buffer_type, buffer_type, ctypes.c_size_t, flagcxDataType_t, + flagcxRedOp_t, flagcxComm_t, flagcxStream_t + ]), + + # Note that flagcxStream_t is a pointer type, so the last argument + # is a pointer + Function("flagcxAllGather", flagcxResult_t, [ + buffer_type, buffer_type, ctypes.c_size_t, flagcxDataType_t, + flagcxComm_t, flagcxStream_t + ]), + + # Note that flagcxStream_t is a pointer type, so the last argument + # is a pointer + Function("flagcxReduceScatter", flagcxResult_t, [ + buffer_type, buffer_type, ctypes.c_size_t, flagcxDataType_t, + flagcxRedOp_t, flagcxComm_t, flagcxStream_t + ]), + + Function("flagcxSend", flagcxResult_t, [ + buffer_type, ctypes.c_size_t, flagcxDataType_t, ctypes.c_int, + flagcxComm_t, flagcxStream_t + ]), + + Function("flagcxRecv", flagcxResult_t, [ + buffer_type, ctypes.c_size_t, flagcxDataType_t, ctypes.c_int, + flagcxComm_t, flagcxStream_t + ]), + + Function("flagcxBroadcast", flagcxResult_t, [ + buffer_type, buffer_type, ctypes.c_size_t, flagcxDataType_t, + ctypes.c_int, flagcxComm_t, flagcxStream_t + ]), + + # be cautious! this is a collective call, it will block until all + # processes in the communicator have called this function. + # because Python object destruction can happen in random order, + # it is better not to call it at all. 
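+        # the binding is still exported below so shutdown paths that
+        # control destruction order explicitly can use it.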
+        # flagcxResult_t flagcxCommDestroy(flagcxComm_t comm);
+        Function("flagcxCommDestroy", flagcxResult_t, [flagcxComm_t]),
+    ]
+
+    # class attribute to store the mapping from the path to the library
+    # to avoid loading the same library multiple times
+    path_to_library_cache: Dict[str, Any] = {}
+
+    # class attribute to store the mapping from library path
+    # to the corresponding dictionary
+    path_to_dict_mapping: Dict[str, Dict[str, Any]] = {}
+
+    def __init__(self, so_file: Optional[str] = None):
+        # `so_file` must point to a libflagcx.so build; unlike the NCCL
+        # wrapper, there is no default discovery helper for FlagCX.
+        if so_file not in FLAGCXLibrary.path_to_dict_mapping:
+            lib = ctypes.CDLL(so_file)
+            FLAGCXLibrary.path_to_library_cache[so_file] = lib
+        self.lib = FLAGCXLibrary.path_to_library_cache[so_file]
+
+        if so_file not in FLAGCXLibrary.path_to_dict_mapping:
+            _funcs: Dict[str, Any] = {}
+            for func in FLAGCXLibrary.exported_functions:
+                f = getattr(self.lib, func.name)
+                f.restype = func.restype
+                f.argtypes = func.argtypes
+                _funcs[func.name] = f
+            FLAGCXLibrary.path_to_dict_mapping[so_file] = _funcs
+        self._funcs = FLAGCXLibrary.path_to_dict_mapping[so_file]
+
+        # init flagcx handler to call device-related apis
+        self.handler = flagcxHandlerGroup_t()
+        self.FLAGCX_CHECK(self._funcs["flagcxHandleInit"](ctypes.byref(self.handler)))
+
+    def __del__(self):
+        # free flagcx handler
+        self.FLAGCX_CHECK(self._funcs["flagcxHandleFree"](self.handler))
+
+    def flagcxGetErrorString(self, result: flagcxResult_t) -> str:
+        return self._funcs["flagcxGetErrorString"](result).decode("utf-8")
+
+    def FLAGCX_CHECK(self, result: flagcxResult_t) -> None:
+        if result != 0:
+            error_str = self.flagcxGetErrorString(result)
+            raise RuntimeError(f"FLAGCX error: {error_str}")
+
+    def flagcxGetVersion(self) -> str:
+        version = ctypes.c_int()
+        self.FLAGCX_CHECK(self._funcs["flagcxGetVersion"](ctypes.byref(version)))
+        version_str = str(version.value)
+        # something like 21903 --> "2.19.3"
+        major = version_str[0].lstrip("0")
+        minor = version_str[1:3].lstrip("0")
+        patch = version_str[3:].lstrip("0")
+        return f"{major}.{minor}.{patch}"
+
+    def flagcxGetUniqueId(self) -> flagcxUniqueId:
+        # Note: this returns a POINTER(flagcxUniqueId); callers reach the
+        # struct itself through `.contents`.
+        unique_id = ctypes.POINTER(flagcxUniqueId)()
+        self.FLAGCX_CHECK(self._funcs["flagcxGetUniqueId"](
+            ctypes.byref(unique_id)))
+        return unique_id
+
+    def unique_id_from_bytes(self, data: bytes) -> flagcxUniqueId:
+        """
+        Reconstructs a `flagcxUniqueId` object from bytes data.
+        Args:
+            data: Must be a 256-byte data block (the size of
+                flagcxUniqueId's internal buffer).
+        Returns:
+            flagcxUniqueId: The reconstructed FlagCX unique ID object.
+        Raises:
+            ValueError: If the input data length is not 256 bytes.
+        """
+        if len(data) != 256:
+            raise ValueError(
+                f"Expected 256 bytes for flagcxUniqueId, got {len(data)} bytes")
+
+        unique_id = flagcxUniqueId()
+        ctypes.memmove(ctypes.addressof(unique_id.internal), data, 256)
+        return unique_id
+
+    def flagcxCommInitRank(self, world_size: int, unique_id: flagcxUniqueId,
+                           rank: int) -> flagcxComm_t:
+        comm = flagcxComm_t()
+        self.FLAGCX_CHECK(self._funcs["flagcxCommInitRank"](ctypes.byref(comm),
+                                                            world_size, unique_id,
+                                                            rank))
+        return comm
+
+    def flagcxAllReduce(self, sendbuff: buffer_type, recvbuff: buffer_type,
+                        count: int, datatype: int, op: int, comm: flagcxComm_t,
+                        stream: flagcxStream_t) -> None:
+        # `datatype` actually should be `flagcxDataType_t`
+        # and `op` should be `flagcxRedOp_t`
+        # both are aliases of `ctypes.c_int`
+        # when we pass int to a function, it will be converted to `ctypes.c_int`
+        # by ctypes automatically
+        self.FLAGCX_CHECK(self._funcs["flagcxAllReduce"](sendbuff, recvbuff, count,
+                                                         datatype, op, comm,
+                                                         stream))
+
+    def flagcxReduceScatter(self, sendbuff: buffer_type, recvbuff: buffer_type,
+                            count: int, datatype: int, op: int, comm: flagcxComm_t,
+                            stream: flagcxStream_t) -> None:
+        # `datatype` actually should be `flagcxDataType_t`
+        # and `op` should be `flagcxRedOp_t`
+        # both are aliases of `ctypes.c_int`
+        # when we pass int to a function, it will be converted to `ctypes.c_int`
+        # by ctypes automatically
+        self.FLAGCX_CHECK(self._funcs["flagcxReduceScatter"](sendbuff, recvbuff,
+                                                             count, datatype, op,
+                                                             comm, stream))
+
+    def flagcxAllGather(self, sendbuff: buffer_type, recvbuff: buffer_type,
+                        count: int, datatype: int, comm: flagcxComm_t,
+                        stream: flagcxStream_t) -> None:
+        # `datatype` actually should be `flagcxDataType_t`
+        # which is an alias of `ctypes.c_int`
+        # when we pass int to a function, it will be converted to `ctypes.c_int`
+        # by ctypes automatically
+        self.FLAGCX_CHECK(self._funcs["flagcxAllGather"](sendbuff, recvbuff, count,
+                                                         datatype, comm, stream))
+
+    def flagcxSend(self, sendbuff: buffer_type, count: int, datatype: int,
+                   dest: int, comm: flagcxComm_t, stream: flagcxStream_t) -> None:
+        self.FLAGCX_CHECK(self._funcs["flagcxSend"](sendbuff, count, datatype,
+                                                    dest, comm, stream))
+
+    def flagcxRecv(self, recvbuff: buffer_type, count: int, datatype: int,
+                   src: int, comm: flagcxComm_t, stream: flagcxStream_t) -> None:
+        self.FLAGCX_CHECK(self._funcs["flagcxRecv"](recvbuff, count, datatype, src,
+                                                    comm, stream))
+
+    def flagcxBroadcast(self, sendbuff: buffer_type, recvbuff: buffer_type,
+                        count: int, datatype: int, root: int, comm: flagcxComm_t,
+                        stream: flagcxStream_t) -> None:
+        self.FLAGCX_CHECK(self._funcs["flagcxBroadcast"](sendbuff, recvbuff, count,
+                                                         datatype, root, comm,
+                                                         stream))
+
+    def flagcxCommDestroy(self, comm: flagcxComm_t) -> None:
+        self.FLAGCX_CHECK(self._funcs["flagcxCommDestroy"](comm))
+
+    def adaptor_stream_create(self):
+        new_stream = flagcxStream_t()
+        self.FLAGCX_CHECK(self.handler.contents.devHandle.contents.streamCreate(ctypes.byref(new_stream)))
+        return new_stream
+
+    def adaptor_stream_copy(self, old_stream):
+        new_stream = flagcxStream_t()
+        self.FLAGCX_CHECK(self.handler.contents.devHandle.contents.streamCopy(ctypes.byref(new_stream), ctypes.byref(cudaStream_t(old_stream.cuda_stream))))
+        return new_stream
+
+    def adaptor_stream_free(self, stream):
+        self.FLAGCX_CHECK(self.handler.contents.devHandle.contents.streamFree(stream))
+
+    def adaptor_stream_destroy(self, stream):
+        self.FLAGCX_CHECK(self.handler.contents.devHandle.contents.streamDestroy(stream))
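+    # Illustrative usage of the adaptor helpers above (a sketch, not part of
+    # the FlagCX API; `lib`, `comm`, `buf`, `count`, `dtype` and `peer` are
+    # placeholder names). This mirrors how the P2P pipe below drives them:
+    #
+    #   s = current_stream()                     # a torch.cuda.Stream
+    #   fx_stream = lib.adaptor_stream_copy(s)   # flagcxStream_t view of s
+    #   lib.flagcxSend(buf, count, dtype, peer, comm, fx_stream)
+    #   lib.adaptor_stream_free(fx_stream)       # releases the wrapper;
+    #                                            # the CUDA stream lives on
+    #
+    # `adaptor_stream_destroy`, by contrast, tears down a stream created
+    # with `adaptor_stream_create`.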
+ + def sync_stream(self, stream): + self.FLAGCX_CHECK(self.handler.contents.devHandle.contents.streamSynchronize(stream)) + + +__all__ = [ + "FLAGCXLibrary", "flagcxDataTypeEnum", "flagcxRedOpTypeEnum", "flagcxUniqueId", + "flagcxHandlerGroup_t", "flagcxComm_t", "flagcxStream_t", "flagcxEvent_t", "buffer_type", "cudaStream_t" +] \ No newline at end of file diff --git a/flagscale/backends/vllm/vllm/distributed/device_communicators/pynccl_wrapper.py b/flagscale/backends/vllm/vllm/distributed/device_communicators/pynccl_wrapper.py new file mode 100644 index 000000000..83619c27f --- /dev/null +++ b/flagscale/backends/vllm/vllm/distributed/device_communicators/pynccl_wrapper.py @@ -0,0 +1,361 @@ +# SPDX-License-Identifier: Apache-2.0 + +# This file is a pure Python wrapper for the NCCL library. +# The main purpose is to use NCCL combined with CUDA graph. +# Before writing this script, we tried the following approach: +# 1. We tried to use `cupy`, it calls NCCL correctly, but `cupy` itself +# often gets stuck when initializing the NCCL communicator. +# 2. We tried to use `torch.distributed`, but `torch.distributed.all_reduce` +# contains many other potential cuda APIs, that are not allowed during +# capturing the CUDA graph. For further details, please check +# https://discuss.pytorch.org/t/pytorch-cudagraph-with-nccl-operation-failed/ . +# +# Another rejected idea is to write a C/C++ binding for NCCL. It is usually +# doable, but we often encounter issues related with nccl versions, and need +# to switch between different versions of NCCL. See +# https://github.com/NVIDIA/nccl/issues/1234 for more details. +# A C/C++ binding is not flexible enough to handle this. It requires +# recompilation of the code every time we want to switch between different +# versions. This current implementation, with a **pure** Python wrapper, is +# more flexible. We can easily switch between different versions of NCCL by +# changing the environment variable `VLLM_NCCL_SO_PATH`, or the `so_file` +# variable in the code. 
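+#
+# For example (paths here are hypothetical), pointing vLLM at a specific
+# NCCL build is just an environment change, with no rebuild required:
+#
+#   VLLM_NCCL_SO_PATH=/usr/local/nccl-2.19.3/lib/libnccl.so.2 \
+#       python -m vllm.entrypoints.openai.api_server --model ...
+#
+# since the wrapper simply dlopen()s whatever library the variable names.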
+ +import ctypes +import platform +from dataclasses import dataclass +from typing import Any, Dict, List, Optional + +import torch +from torch.distributed import ReduceOp + +from vllm.logger import init_logger +from vllm.utils import find_nccl_library + +logger = init_logger(__name__) + +# === export types and functions from nccl to Python === +# for the original nccl definition, please check +# https://github.com/NVIDIA/nccl/blob/master/src/nccl.h.in + +ncclResult_t = ctypes.c_int +ncclComm_t = ctypes.c_void_p + + +class ncclUniqueId(ctypes.Structure): + _fields_ = [("internal", ctypes.c_byte * 128)] + + +cudaStream_t = ctypes.c_void_p +buffer_type = ctypes.c_void_p + +ncclDataType_t = ctypes.c_int + + +class ncclDataTypeEnum: + ncclInt8 = 0 + ncclChar = 0 + ncclUint8 = 1 + ncclInt32 = 2 + ncclInt = 2 + ncclUint32 = 3 + ncclInt64 = 4 + ncclUint64 = 5 + ncclFloat16 = 6 + ncclHalf = 6 + ncclFloat32 = 7 + ncclFloat = 7 + ncclFloat64 = 8 + ncclDouble = 8 + ncclBfloat16 = 9 + ncclNumTypes = 10 + + @classmethod + def from_torch(cls, dtype: torch.dtype) -> int: + if dtype == torch.int8: + return cls.ncclInt8 + if dtype == torch.uint8: + return cls.ncclUint8 + if dtype == torch.int32: + return cls.ncclInt32 + if dtype == torch.int64: + return cls.ncclInt64 + if dtype == torch.float16: + return cls.ncclFloat16 + if dtype == torch.float32: + return cls.ncclFloat32 + if dtype == torch.float64: + return cls.ncclFloat64 + if dtype == torch.bfloat16: + return cls.ncclBfloat16 + raise ValueError(f"Unsupported dtype: {dtype}") + + +ncclRedOp_t = ctypes.c_int + + +class ncclRedOpTypeEnum: + ncclSum = 0 + ncclProd = 1 + ncclMax = 2 + ncclMin = 3 + ncclAvg = 4 + ncclNumOps = 5 + + @classmethod + def from_torch(cls, op: ReduceOp) -> int: + if op == ReduceOp.SUM: + return cls.ncclSum + if op == ReduceOp.PRODUCT: + return cls.ncclProd + if op == ReduceOp.MAX: + return cls.ncclMax + if op == ReduceOp.MIN: + return cls.ncclMin + if op == ReduceOp.AVG: + return cls.ncclAvg + raise ValueError(f"Unsupported op: {op}") + + +@dataclass +class Function: + name: str + restype: Any + argtypes: List[Any] + + +class NCCLLibrary: + exported_functions = [ + # const char* ncclGetErrorString(ncclResult_t result) + Function("ncclGetErrorString", ctypes.c_char_p, [ncclResult_t]), + # ncclResult_t ncclGetVersion(int *version); + Function("ncclGetVersion", ncclResult_t, + [ctypes.POINTER(ctypes.c_int)]), + # ncclResult_t ncclGetUniqueId(ncclUniqueId* uniqueId); + Function("ncclGetUniqueId", ncclResult_t, + [ctypes.POINTER(ncclUniqueId)]), + # ncclResult_t ncclCommInitRank( + # ncclComm_t* comm, int nranks, ncclUniqueId commId, int rank); + # note that ncclComm_t is a pointer type, so the first argument + # is a pointer to a pointer + Function("ncclCommInitRank", ncclResult_t, [ + ctypes.POINTER(ncclComm_t), ctypes.c_int, ncclUniqueId, + ctypes.c_int + ]), + # ncclResult_t ncclAllReduce( + # const void* sendbuff, void* recvbuff, size_t count, + # ncclDataType_t datatype, ncclRedOp_t op, ncclComm_t comm, + # cudaStream_t stream); + # note that cudaStream_t is a pointer type, so the last argument + # is a pointer + Function("ncclAllReduce", ncclResult_t, [ + buffer_type, buffer_type, ctypes.c_size_t, ncclDataType_t, + ncclRedOp_t, ncclComm_t, cudaStream_t + ]), + + # ncclResult_t ncclAllGather( + # const void* sendbuff, void* recvbuff, size_t count, + # ncclDataType_t datatype, ncclComm_t comm, + # cudaStream_t stream); + # note that cudaStream_t is a pointer type, so the last argument + # is a pointer + Function("ncclAllGather", 
ncclResult_t, [ + buffer_type, buffer_type, ctypes.c_size_t, ncclDataType_t, + ncclComm_t, cudaStream_t + ]), + + # ncclResult_t ncclReduceScatter( + # const void* sendbuff, void* recvbuff, size_t count, + # ncclDataType_t datatype, ncclRedOp_t op, ncclComm_t comm, + # cudaStream_t stream); + # note that cudaStream_t is a pointer type, so the last argument + # is a pointer + Function("ncclReduceScatter", ncclResult_t, [ + buffer_type, buffer_type, ctypes.c_size_t, ncclDataType_t, + ncclRedOp_t, ncclComm_t, cudaStream_t + ]), + + # ncclResult_t ncclSend( + # const void* sendbuff, size_t count, ncclDataType_t datatype, + # int dest, ncclComm_t comm, cudaStream_t stream); + Function("ncclSend", ncclResult_t, [ + buffer_type, ctypes.c_size_t, ncclDataType_t, ctypes.c_int, + ncclComm_t, cudaStream_t + ]), + + # ncclResult_t ncclRecv( + # void* recvbuff, size_t count, ncclDataType_t datatype, + # int src, ncclComm_t comm, cudaStream_t stream); + Function("ncclRecv", ncclResult_t, [ + buffer_type, ctypes.c_size_t, ncclDataType_t, ctypes.c_int, + ncclComm_t, cudaStream_t + ]), + + # ncclResult_t ncclBroadcast( + # const void* sendbuff, void* recvbuff, size_t count, + # ncclDataType_t datatype, int root, ncclComm_t comm, + # cudaStream_t stream); + Function("ncclBroadcast", ncclResult_t, [ + buffer_type, buffer_type, ctypes.c_size_t, ncclDataType_t, + ctypes.c_int, ncclComm_t, cudaStream_t + ]), + + # be cautious! this is a collective call, it will block until all + # processes in the communicator have called this function. + # because Python object destruction can happen in random order, + # it is better not to call it at all. + # ncclResult_t ncclCommDestroy(ncclComm_t comm); + Function("ncclCommDestroy", ncclResult_t, [ncclComm_t]), + ] + + # class attribute to store the mapping from the path to the library + # to avoid loading the same library multiple times + path_to_library_cache: Dict[str, Any] = {} + + # class attribute to store the mapping from library path + # to the corresponding dictionary + path_to_dict_mapping: Dict[str, Dict[str, Any]] = {} + + def __init__(self, so_file: Optional[str] = None): + + so_file = so_file or find_nccl_library() + + try: + if so_file not in NCCLLibrary.path_to_dict_mapping: + lib = ctypes.CDLL(so_file) + NCCLLibrary.path_to_library_cache[so_file] = lib + self.lib = NCCLLibrary.path_to_library_cache[so_file] + except Exception as e: + logger.error( + "Failed to load NCCL library from %s. " + "It is expected if you are not running on NVIDIA/AMD GPUs." + "Otherwise, the nccl library might not exist, be corrupted " + "or it does not support the current platform %s. 
" + "If you already have the library, please set the " + "environment variable VLLM_NCCL_SO_PATH" + " to point to the correct nccl library path.", so_file, + platform.platform()) + raise e + + if so_file not in NCCLLibrary.path_to_dict_mapping: + _funcs: Dict[str, Any] = {} + for func in NCCLLibrary.exported_functions: + f = getattr(self.lib, func.name) + f.restype = func.restype + f.argtypes = func.argtypes + _funcs[func.name] = f + NCCLLibrary.path_to_dict_mapping[so_file] = _funcs + self._funcs = NCCLLibrary.path_to_dict_mapping[so_file] + + def ncclGetErrorString(self, result: ncclResult_t) -> str: + return self._funcs["ncclGetErrorString"](result).decode("utf-8") + + def NCCL_CHECK(self, result: ncclResult_t) -> None: + if result != 0: + error_str = self.ncclGetErrorString(result) + raise RuntimeError(f"NCCL error: {error_str}") + + def ncclGetVersion(self) -> str: + version = ctypes.c_int() + self.NCCL_CHECK(self._funcs["ncclGetVersion"](ctypes.byref(version))) + version_str = str(version.value) + # something like 21903 --> "2.19.3" + major = version_str[0].lstrip("0") + minor = version_str[1:3].lstrip("0") + patch = version_str[3:].lstrip("0") + return f"{major}.{minor}.{patch}" + + def ncclGetUniqueId(self) -> ncclUniqueId: + unique_id = ncclUniqueId() + self.NCCL_CHECK(self._funcs["ncclGetUniqueId"]( + ctypes.byref(unique_id))) + return unique_id + + def unique_id_from_bytes(self, data: bytes) -> ncclUniqueId: + """ + Reconstructs an `ncclUniqueId` object from bytes data. + + Args: + data: Must be a 128-byte data block (matching NCCL's unique_id). + + Returns: + ncclUniqueId: The reconstructed NCCL Unique ID object. + + Raises: + ValueError: If the input data length is not 128 bytes. + """ + if len(data) != 128: + raise ValueError( + f"Expected 128 bytes for ncclUniqueId, got {len(data)} bytes") + + unique_id = ncclUniqueId() + ctypes.memmove(ctypes.addressof(unique_id.internal), data, 128) + return unique_id + + def ncclCommInitRank(self, world_size: int, unique_id: ncclUniqueId, + rank: int) -> ncclComm_t: + comm = ncclComm_t() + self.NCCL_CHECK(self._funcs["ncclCommInitRank"](ctypes.byref(comm), + world_size, unique_id, + rank)) + return comm + + def ncclAllReduce(self, sendbuff: buffer_type, recvbuff: buffer_type, + count: int, datatype: int, op: int, comm: ncclComm_t, + stream: cudaStream_t) -> None: + # `datatype` actually should be `ncclDataType_t` + # and `op` should be `ncclRedOp_t` + # both are aliases of `ctypes.c_int` + # when we pass int to a function, it will be converted to `ctypes.c_int` + # by ctypes automatically + self.NCCL_CHECK(self._funcs["ncclAllReduce"](sendbuff, recvbuff, count, + datatype, op, comm, + stream)) + + def ncclReduceScatter(self, sendbuff: buffer_type, recvbuff: buffer_type, + count: int, datatype: int, op: int, comm: ncclComm_t, + stream: cudaStream_t) -> None: + # `datatype` actually should be `ncclDataType_t` + # and `op` should be `ncclRedOp_t` + # both are aliases of `ctypes.c_int` + # when we pass int to a function, it will be converted to `ctypes.c_int` + # by ctypes automatically + self.NCCL_CHECK(self._funcs["ncclReduceScatter"](sendbuff, recvbuff, + count, datatype, op, + comm, stream)) + + def ncclAllGather(self, sendbuff: buffer_type, recvbuff: buffer_type, + count: int, datatype: int, comm: ncclComm_t, + stream: cudaStream_t) -> None: + # `datatype` actually should be `ncclDataType_t` + # which is an aliases of `ctypes.c_int` + # when we pass int to a function, it will be converted to `ctypes.c_int` + # by ctypes automatically + 
self.NCCL_CHECK(self._funcs["ncclAllGather"](sendbuff, recvbuff, count, + datatype, comm, stream)) + + def ncclSend(self, sendbuff: buffer_type, count: int, datatype: int, + dest: int, comm: ncclComm_t, stream: cudaStream_t) -> None: + self.NCCL_CHECK(self._funcs["ncclSend"](sendbuff, count, datatype, + dest, comm, stream)) + + def ncclRecv(self, recvbuff: buffer_type, count: int, datatype: int, + src: int, comm: ncclComm_t, stream: cudaStream_t) -> None: + self.NCCL_CHECK(self._funcs["ncclRecv"](recvbuff, count, datatype, src, + comm, stream)) + + def ncclBroadcast(self, sendbuff: buffer_type, recvbuff: buffer_type, + count: int, datatype: int, root: int, comm: ncclComm_t, + stream: cudaStream_t) -> None: + self.NCCL_CHECK(self._funcs["ncclBroadcast"](sendbuff, recvbuff, count, + datatype, root, comm, + stream)) + + def ncclCommDestroy(self, comm: ncclComm_t) -> None: + self.NCCL_CHECK(self._funcs["ncclCommDestroy"](comm)) + + +__all__ = [ + "NCCLLibrary", "ncclDataTypeEnum", "ncclRedOpTypeEnum", "ncclUniqueId", + "ncclComm_t", "cudaStream_t", "buffer_type" +] diff --git a/flagscale/backends/vllm/vllm/distributed/kv_transfer/kv_connector/factory.py b/flagscale/backends/vllm/vllm/distributed/kv_transfer/kv_connector/factory.py new file mode 100644 index 000000000..39c9809ce --- /dev/null +++ b/flagscale/backends/vllm/vllm/distributed/kv_transfer/kv_connector/factory.py @@ -0,0 +1,64 @@ +# SPDX-License-Identifier: Apache-2.0 + +import importlib +from typing import TYPE_CHECKING, Callable, Dict, Type + +from .base import KVConnectorBase + +if TYPE_CHECKING: + from vllm.config import VllmConfig + + +class KVConnectorFactory: + _registry: Dict[str, Callable[[], Type[KVConnectorBase]]] = {} + + @classmethod + def register_connector(cls, name: str, module_path: str, + class_name: str) -> None: + """Register a connector with a lazy-loading module and class name.""" + if name in cls._registry: + raise ValueError(f"Connector '{name}' is already registered.") + + def loader() -> Type[KVConnectorBase]: + module = importlib.import_module(module_path) + return getattr(module, class_name) + + cls._registry[name] = loader + + @classmethod + def create_connector(cls, rank: int, local_rank: int, + config: "VllmConfig") -> KVConnectorBase: + connector_name = config.kv_transfer_config.kv_connector + if connector_name not in cls._registry: + raise ValueError(f"Unsupported connector type: {connector_name}") + + connector_cls = cls._registry[connector_name]() + return connector_cls(rank, local_rank, config) + + +# Register various connectors here. +# The registration should not be done in each individual file, as we want to +# only load the files corresponding to the current connector. 
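+# As a sketch, an out-of-tree connector could be registered the same way
+# before it is requested (module path and class name below are hypothetical):
+#
+#   KVConnectorFactory.register_connector(
+#       "MyConnector", "my_package.my_connector", "MyConnector")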
+KVConnectorFactory.register_connector( + "P2pConnector", "vllm.distributed.kv_transfer.kv_connector.p2p_connector", + "P2pConnector") + +KVConnectorFactory.register_connector( + "PyNcclConnector", + "vllm.distributed.kv_transfer.kv_connector.simple_connector", + "SimpleConnector") + +KVConnectorFactory.register_connector( + "MooncakeConnector", + "vllm.distributed.kv_transfer.kv_connector.simple_connector", + "SimpleConnector") + +KVConnectorFactory.register_connector( + "LMCacheConnector", + "vllm.distributed.kv_transfer.kv_connector.lmcache_connector", + "LMCacheConnector") + +KVConnectorFactory.register_connector( + "MooncakeStoreConnector", + "vllm.distributed.kv_transfer.kv_connector.mooncake_store_connector", + "MooncakeStoreConnector") \ No newline at end of file diff --git a/flagscale/backends/vllm/vllm/distributed/kv_transfer/kv_connector/p2p_connector.py b/flagscale/backends/vllm/vllm/distributed/kv_transfer/kv_connector/p2p_connector.py new file mode 100644 index 000000000..b0e2d96ba --- /dev/null +++ b/flagscale/backends/vllm/vllm/distributed/kv_transfer/kv_connector/p2p_connector.py @@ -0,0 +1,300 @@ +# SPDX-License-Identifier: Apache-2.0 + +import re +from typing import TYPE_CHECKING, List, Tuple, Union + +import torch + +import vllm.envs as envs +from vllm import _custom_ops as ops +from vllm.config import VllmConfig +from vllm.distributed.kv_transfer.kv_connector.base import KVConnectorBase +from vllm.distributed.kv_transfer.kv_pipe.p2p_nccl_pipe import P2pNcclPipe +from vllm.logger import init_logger +from vllm.sequence import IntermediateTensors + +if TYPE_CHECKING: + from vllm.worker.model_runner import ModelInputForGPUWithSamplingMetadata + +logger = init_logger(__name__) + + +class P2pConnector(KVConnectorBase): + + def __init__( + self, + rank: int, + local_rank: int, + config: VllmConfig, + ): + self.rank = rank + self.config = config.kv_transfer_config + self.tp_size = config.parallel_config.tensor_parallel_size + self.is_deepseek_mla = config.model_config.is_deepseek_mla + self.use_mla_opt = not envs.VLLM_MLA_DISABLE + + assert self.config.kv_connector == "P2pConnector" + + self.lookup_buffer_size = self.config.kv_buffer_size + + self.p2p_nccl_pipe = P2pNcclPipe( + local_rank=local_rank, + config=self.config, + hostname="", + port_offset=rank, + ) + + def send_kv_caches_and_hidden_states( + self, + model_executable: torch.nn.Module, + model_input: "ModelInputForGPUWithSamplingMetadata", + kv_caches: List[torch.Tensor], + hidden_or_intermediate_states: Union[torch.Tensor, + IntermediateTensors], + ) -> None: + + # input_tokens_tensor = model_input.input_tokens + seq_lens = model_input.attn_metadata.seq_lens + slot_mapping_flat = model_input.attn_metadata.slot_mapping.flatten() + num_prefill_tokens = model_input.attn_metadata.num_prefill_tokens + request_ids = list(model_input.request_ids_to_seq_ids.keys()) + start_layer = model_executable.model.start_layer + end_layer = model_executable.model.end_layer + + model_config = model_executable.model.config + num_heads = int(model_config.num_key_value_heads / self.tp_size) + hidden_size = model_config.hidden_size + num_attention_heads = model_config.num_attention_heads + + # Deepseek's MLA (Multi-head Latent Attention) uses two different + # kv_cache shapes based on whether VLLM_MLA_DISABLE is set to 0. + # When VLLM_MLA_DISABLE=0 (default), forward absorb is applied, + # resulting in a kv_cache shape of [num_blks, blk_size, 1, + # kv_lora_rank + qk_rope_head_dim]. 
+        # When VLLM_MLA_DISABLE=1, standard FA is used instead, leading
+        # to a kv_cache shape of [2, num_blks, blk_size,
+        # num_key_value_heads / tp, qk_nope_head_dim + qk_rope_head_dim].
+        # For more details, see vllm/attention/backends/mla/common.py.
+        if self.is_deepseek_mla and self.use_mla_opt:
+            head_size = model_config.kv_lora_rank + \
+                model_config.qk_rope_head_dim
+            num_heads = 1
+        elif self.is_deepseek_mla and not self.use_mla_opt:
+            head_size = model_config.qk_nope_head_dim + \
+                model_config.qk_rope_head_dim
+        else:
+            head_size = getattr(model_config, "head_dim",
+                                int(hidden_size // num_attention_heads))
+
+        # seq_lens covers the KV caches newly added to vLLM during prefill,
+        # so we send those entries to the decode instance.
+        # FIXME(Kuntai): This assumes that all requests are prefill.
+        for idx, slen in enumerate(seq_lens):
+            start_pos = sum(seq_lens[:idx])
+            end_pos = start_pos + slen
+
+            if start_pos >= num_prefill_tokens:
+                # vllm/worker/model_runner.py::_prepare_model_input_tensors:
+                # - input_tokens[:num_prefill_tokens] contains prefill tokens.
+                # - input_tokens[num_prefill_tokens:] contains decode tokens.
+                logger.warning("You have some decode requests while using "
+                               "P2pConnector. Their KVCache won't be sent.")
+                break
+
+            # current_tokens = input_tokens_tensor[start_pos:end_pos]
+
+            keys, values = [], []
+
+            for layer_id in range(start_layer, end_layer):
+                kv_cache = kv_caches[layer_id - start_layer]
+
+                if self.is_deepseek_mla and self.use_mla_opt:
+                    key_cache = kv_cache.reshape(-1, num_heads, head_size)
+                    value_cache = kv_cache.reshape(-1, num_heads, head_size)
+                else:
+                    key_cache = kv_cache[0].reshape(-1, num_heads, head_size)
+                    value_cache = kv_cache[1].reshape(-1, num_heads, head_size)
+
+                current_slot_mapping = slot_mapping_flat[start_pos:end_pos]
+
+                keys.append(key_cache[current_slot_mapping].unsqueeze(0))
+                values.append(value_cache[current_slot_mapping].unsqueeze(0))
+
+            keys = torch.cat(keys, dim=0)
+            values = torch.cat(values, dim=0)
+
+            request_id = request_ids[idx]
+            ip, port = self.parse_request_id(request_id, True)
+            remote_address = ip + ":" + str(port + self.rank)
+
+            self.p2p_nccl_pipe.send_tensor(request_id + "keys", keys,
+                                           remote_address)
+            self.p2p_nccl_pipe.send_tensor(request_id + "values", values,
+                                           remote_address)
+            self.p2p_nccl_pipe.send_tensor(
+                request_id + "hidden",
+                hidden_or_intermediate_states[start_pos:end_pos],
+                remote_address)
+
+        logger.debug("[rank%d]: KV send DONE.", torch.distributed.get_rank())
+
+    def recv_kv_caches_and_hidden_states(
+        self, model_executable: torch.nn.Module,
+        model_input: "ModelInputForGPUWithSamplingMetadata",
+        kv_caches: List[torch.Tensor]
+    ) -> Tuple[Union[torch.Tensor, IntermediateTensors], bool,
+               "ModelInputForGPUWithSamplingMetadata"]:
+
+        # When bypass_model_exec is set to False, it means that at least for one
+        # request its corresponding KV cache or hidden state is missing.
+        # In this case we need to do prefilling to recompute missing KV cache
+        # and hidden states.
+        bypass_model_exec = True
+
+        model_config = model_executable.model.config
+
+        input_tokens_tensor = model_input.input_tokens
+        seq_lens = model_input.attn_metadata.seq_lens
+        num_prefill_tokens = model_input.attn_metadata.num_prefill_tokens
+        slot_mapping = model_input.attn_metadata.slot_mapping.flatten()
+        request_ids = list(model_input.request_ids_to_seq_ids.keys())
+
+        hidden_or_intermediate_states_for_one_req = []
+
+        input_tokens_list = []
+        num_computed_tokens_list = []
+        start_pos_list = []
+
+        # enumerate different requests
+        # FIXME(Kuntai): This impl assumes that all requests are prefill.
+        for idx, slen in enumerate(seq_lens):
+            start_pos = sum(seq_lens[:idx])
+            end_pos = start_pos + slen
+
+            if start_pos >= num_prefill_tokens:
+                # This can happen during inflight batching. See:
+                # vllm/worker/model_runner.py::_prepare_model_input_tensors:
+                # - input_tokens[:num_prefill_tokens] contains prefill tokens.
+                # - input_tokens[num_prefill_tokens:] contains decode tokens.
+                logger.warning("You should set --enable_chunked_prefill=False "
+                               "and set --max_num_batched_tokens equal to "
+                               "--max_seq_len_to_capture")
+                bypass_model_exec = False
+                assert start_pos == num_prefill_tokens
+                break
+
+            current_tokens = input_tokens_tensor[start_pos:end_pos]
+            num_tokens = slen
+
+            # collecting data for rebuilding the input
+            input_tokens_list.append(current_tokens)
+            start_pos_list.append(start_pos)
+
+            request_id = request_ids[idx]
+            ip, port = self.parse_request_id(request_id, False)
+            remote_address = ip + ":" + str(port + self.rank)
+
+            keys = self.p2p_nccl_pipe.recv_tensor(request_id + "keys",
+                                                  remote_address)
+            values = self.p2p_nccl_pipe.recv_tensor(request_id + "values",
+                                                    remote_address)
+            hidden = self.p2p_nccl_pipe.recv_tensor(request_id + "hidden",
+                                                    remote_address)
+
+            num_computed_tokens = current_tokens.shape[0]
+            num_computed_tokens_list.append(num_computed_tokens)
+
+            # check if both the KV cache and the hidden states were received;
+            # if not, redo the forwarding to compute the missing states
+            if not all([(num_computed_tokens == num_tokens), keys is not None,
+                        values is not None, hidden is not None]):
+                bypass_model_exec = False
+                break
+
+            # update the end position based on how many tokens are cached.
+ end_pos = start_pos + num_computed_tokens + + # put received KV caches into paged memory + for i in range(model_executable.model.start_layer, + model_executable.model.end_layer): + + kv_cache = kv_caches[i - model_executable.model.start_layer] + layer = model_executable.model.layers[i] + + if self.is_deepseek_mla and self.use_mla_opt: + layer.self_attn.attn = layer.self_attn.mla_attn + k_c_normed_k_pe = keys[ + i - model_executable.model.start_layer].to( + kv_cache.device).squeeze(1) + k_c_normed = k_c_normed_k_pe[:, :model_config.kv_lora_rank] + k_pe = k_c_normed_k_pe[:, model_config.kv_lora_rank:] + ops.concat_and_cache_mla( + k_c_normed, + k_pe, + kv_cache, + slot_mapping[start_pos:end_pos], + layer.self_attn.attn.kv_cache_dtype, + layer.self_attn.attn._k_scale, + ) + else: + key_cache, value_cache = kv_cache[0], kv_cache[1] + ops.reshape_and_cache_flash( + keys[i - model_executable.model.start_layer].to( + key_cache.device), + values[i - model_executable.model.start_layer].to( + value_cache.device), + key_cache, + value_cache, + slot_mapping[start_pos:end_pos], + layer.self_attn.attn.kv_cache_dtype, + layer.self_attn.attn._k_scale, + layer.self_attn.attn._v_scale, + ) + + hidden_or_intermediate_states_for_one_req.append(hidden) + + if not bypass_model_exec: + # Some of the KV cache is not retrieved + # Here we will fall back to normal model forwarding + # But optionally you can adjust model_input so that you only do + # prefilling on those tokens that are missing KV caches. + logger.warning( + "[rank%d]: Failed to receive all KVs and hidden " + "states, redo model forwarding.", torch.distributed.get_rank()) + hidden_or_intermediate_states = None + + else: + logger.debug( + "[rank%d]: Successfully received all KVs and hidden " + "states, skip model forwarding.", torch.distributed.get_rank()) + hidden_or_intermediate_states = torch.cat( + hidden_or_intermediate_states_for_one_req, dim=0) + + return hidden_or_intermediate_states, bypass_model_exec, model_input + + @staticmethod + def parse_request_id(request_id: str, is_prefill=True) -> Tuple[str, int]: + logger.debug("parse_request_id, request_id: %s, is_prefill: %s", + request_id, is_prefill) + # Regular expression to match the string hostname and integer port + if is_prefill: + pattern = r"___decode_addr_(.*):(\d+)" + else: + pattern = r"___prefill_addr_(.*):(\d+)___" + + # Use re.search to find the pattern in the request_id + match = re.search(pattern, request_id) + if match: + # Extract the ranks + ip = match.group(1) + port = int(match.group(2)) + + logger.debug("parse_request_id, request_id: %s, ip: %s, port: %s", + request_id, ip, str(port)) + return ip, port + raise ValueError( + f"Request id {request_id} does not contain hostname and port") + + def close(self): + self.p2p_nccl_pipe.close() + diff --git a/flagscale/backends/vllm/vllm/distributed/kv_transfer/kv_pipe/flagcx_p2p_nccl_pipe.py b/flagscale/backends/vllm/vllm/distributed/kv_transfer/kv_pipe/flagcx_p2p_nccl_pipe.py new file mode 100644 index 000000000..e62a5f656 --- /dev/null +++ b/flagscale/backends/vllm/vllm/distributed/kv_transfer/kv_pipe/flagcx_p2p_nccl_pipe.py @@ -0,0 +1,434 @@ +# SPDX-License-Identifier: Apache-2.0 + +import logging +import threading +import time +import typing +from collections import deque +from typing import Any, Deque, Dict, List, Optional + +import msgpack +import torch +import zmq +import os + +from vllm.config import KVTransferConfig + + +from vllm.distributed.device_communicators.flagcx_wrapper import ( + FLAGCXLibrary, buffer_type, 
flagcxComm_t, flagcxDataTypeEnum,
+    flagcxRedOpTypeEnum, flagcxUniqueId, cudaStream_t)
+
+from vllm.utils import current_stream, get_ip
+
+logger = logging.getLogger(__name__)
+
+
+class P2pNcclPipe:
+
+    def __init__(self,
+                 local_rank: int,
+                 config: KVTransferConfig,
+                 hostname: str = "",
+                 port_offset: int = 0,
+                 library_path: Optional[str] = None) -> None:
+        self.config = config
+        self.rank = port_offset
+        self.local_rank = local_rank
+        self.device = torch.device(f"cuda:{self.local_rank}")
+        if library_path is None:
+            flagcx_path = os.getenv("FLAGCX_PATH")
+            if not flagcx_path:
+                raise ValueError(
+                    "Either pass library_path or set the FLAGCX_PATH "
+                    "environment variable to locate build/lib/libflagcx.so")
+            library_path = os.path.join(flagcx_path, "build/lib/libflagcx.so")
+        logger.info("flagcx library_path: %s", library_path)
+        self.nccl = FLAGCXLibrary(library_path)
+
+        if not hostname:
+            hostname = get_ip()
+        port = self.config.kv_port + port_offset
+        if port == 0:
+            raise ValueError("Port cannot be 0")
+        self._hostname = hostname
+        self._port = port
+
+        # Each card corresponds to a ZMQ address.
+        self.zmq_address = f"{self._hostname}:{self._port}"
+
+        # The `http_port` must be consistent with the port of the
+        # OpenAI-compatible API server.
+        self.http_address = (
+            f"{self._hostname}:"
+            f"{self.config.kv_connector_extra_config['http_port']}")
+
+        # If `proxy_ip` or `proxy_port` is `""`,
+        # then the ping thread will not be enabled.
+        proxy_ip = self.config.get_from_extra_config("proxy_ip", "")
+        proxy_port = self.config.get_from_extra_config("proxy_port", "")
+        if proxy_ip == "" or proxy_port == "":
+            self.proxy_address = ""
+        else:
+            self.proxy_address = proxy_ip + ":" + proxy_port
+
+        self.context = zmq.Context()
+        self.router_socket = self.context.socket(zmq.ROUTER)
+        self.router_socket.bind(f"tcp://{self.zmq_address}")
+
+        self.poller = zmq.Poller()
+        self.poller.register(self.router_socket, zmq.POLLIN)
+
+        self.send_store_cv = threading.Condition()
+        self.send_queue_cv = threading.Condition()
+        self.recv_store_cv = threading.Condition()
+        self.comm_cv = threading.Condition()
+
+        # The sending type includes three mutually exclusive options:
+        # PUT, GET, PUT_ASYNC.
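+        # A sketch of the assumed deployment config (field names follow
+        # KVTransferConfig; the values here are illustrative only):
+        #   --kv-transfer-config '{"kv_connector": "P2pConnector",
+        #     "kv_connector_extra_config": {"send_type": "PUT_ASYNC", ...}}'
+        # PUT:       the producer pushes each tensor and blocks on the ack.
+        # PUT_ASYNC: the producer enqueues; a daemon thread pushes in order.
+        # GET:       the producer buffers locally and the consumer pulls.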
+ self.send_type = self.config.get_from_extra_config("send_type", "PUT") + if self.send_type == "GET": + self.send_store: Dict[str, + torch.Tensor] = {} # tensor_id: torch.Tensor + else: + # PUT or PUT_ASYNC + self.send_queue: Deque[ + List[Any]] = deque() # tensor_id: torch.Tensor + if self.send_type == "PUT_ASYNC": + self._send_thread = threading.Thread(target=self._send_async, + daemon=True) + self._send_thread.start() + + self.recv_store: Dict[str, + torch.Tensor] = {} # tensor_id: torch.Tensor + self.socks: Dict[str, Any] = {} # remote_address: client socket + self.comms: Dict[str, Any] = {} # remote_address: (flagcxComm_t, rank) + + self.buffer_size = 0 + self.buffer_size_threshold = self.config.kv_buffer_size + + self._listener_thread = threading.Thread( + target=self._listen_for_requests, daemon=True) + self._listener_thread.start() + + self._ping_thread = None + if port_offset == 0 and self.proxy_address != "": + self._ping_thread = threading.Thread(target=self._ping, + daemon=True) + self._ping_thread.start() + + def _create_connect(self, remote_address: typing.Optional[str] = None): + assert remote_address is not None + if remote_address not in self.socks: + sock = self.context.socket(zmq.DEALER) + sock.setsockopt_string(zmq.IDENTITY, self.zmq_address) + sock.connect(f"tcp://{remote_address}") + self.socks[remote_address] = sock + if remote_address in self.comms: + logger.info("👋comm exists, remote_address:%s, comms:%s", + remote_address, self.comms) + return sock, self.comms[remote_address] + + unique_id = self.nccl.flagcxGetUniqueId() + data = {"cmd": "NEW", "unique_id": bytes(unique_id.internal)} + sock.send(msgpack.dumps(data)) + + with torch.cuda.device(self.device): + rank = 0 + comm: flagcxComm_t = self.nccl.flagcxCommInitRank( + 2, unique_id, rank) + self.comms[remote_address] = (comm, rank) + logger.info("🤝flagcxCommInitRank Success, %s👉%s, MyRank: %s", + self.zmq_address, remote_address, rank) + + return self.socks[remote_address], self.comms[remote_address] + + def send_tensor( + self, + tensor_id: str, + tensor: torch.Tensor, + remote_address: typing.Optional[str] = None, + ) -> bool: + if remote_address is None: + with self.recv_store_cv: + self.recv_store[tensor_id] = tensor + self.recv_store_cv.notify() + return True + else: + if self.send_type == "PUT": + return self._send_sync(tensor_id, tensor, remote_address) + elif self.send_type == "PUT_ASYNC": + with self.send_queue_cv: + self.send_queue.append([tensor_id, remote_address, tensor]) + self.send_queue_cv.notify() + else: # GET + with self.send_store_cv: + tensor_size = tensor.element_size() * tensor.numel() + while (self.buffer_size + tensor_size + > self.buffer_size_threshold): + oldest_tenser_id = next(iter(self.send_store)) + oldest_tenser = self.send_store.pop(oldest_tenser_id) + oldest_tenser_size = oldest_tenser.element_size( + ) * oldest_tenser.numel() + self.buffer_size -= oldest_tenser_size + logger.info( + "⛔[GET]Send to %s, tensor_id:%s, tensor_size:%d," + " buffer_size:%d, oldest_tenser_size:%d, rank:%d", + remote_address, tensor_id, tensor_size, + self.buffer_size, oldest_tenser_size, self.rank) + + self.send_store[tensor_id] = tensor + self.buffer_size += tensor_size + logger.info( + "🔵[GET]Send to %s, tensor_id:%s, tensor_size:%d, " + "shape:%s, rank:%d, buffer_size:%d(%.2f%%)", + remote_address, tensor_id, tensor_size, tensor.shape, + self.rank, self.buffer_size, + self.buffer_size / self.buffer_size_threshold * 100) + + return True + + def recv_tensor( + self, + tensor_id: str, + 
remote_address: typing.Optional[str] = None, + ) -> torch.Tensor: + if self.send_type == "PUT" or self.send_type == "PUT_ASYNC": + start_time = time.time() + with self.recv_store_cv: + while tensor_id not in self.recv_store: + self.recv_store_cv.wait() + # TODO:Abatom, To avoid an overly large dictionary. + # tensor = self.recv_store.pop(tensor_id) + tensor = self.recv_store[tensor_id] + self.recv_store[tensor_id] = None + duration = time.time() - start_time + if tensor is not None: + self.buffer_size -= (tensor.element_size() * tensor.numel()) + logger.info( + "🔵[PUT]Recv From %s, tensor_id:%s, shape:%s, " + "duration:%.3fms, size:%.3fGB, rank:%d", remote_address, + tensor_id, tensor.shape, duration * 1000, + tensor.element_size() * tensor.numel() / 1024**3, + self.rank) + else: + logger.warning( + "🔴[PUT]Recv From %s, tensor_id:%s, duration:%.3fms, " + "rank:%d", remote_address, tensor_id, duration * 1000, + self.rank) + return tensor + + # GET + if remote_address is None: + return None + + if remote_address not in self.socks: + self._create_connect(remote_address) + + sock = self.socks[remote_address] + comm, rank = self.comms[remote_address] + + data = {"cmd": "GET", "tensor_id": tensor_id} + sock.send(msgpack.dumps(data)) + + message = sock.recv() + data = msgpack.loads(message) + if data["ret"] != 0: + logger.warning("🔴[GET]Recv From %s, tensor_id: %s, ret: %d", + remote_address, tensor_id, data["ret"]) + return None + + tensor = torch.empty(data["shape"], + dtype=getattr(torch, data["dtype"]), + device=self.device) + + start_time = time.time() + self._recv(comm, tensor, rank ^ 1) + duration = time.time() - start_time + logger.info( + "🔵[GET]Recv From %s, tensor_id:%s, shape:%s, duration:%.3fms, " + "size:%.3fGB, rank:%d", remote_address, tensor_id, tensor.shape, + duration * 1000, + tensor.element_size() * tensor.numel() / 1024**3, self.rank) + + return tensor + + def _listen_for_requests(self): + while True: + socks = dict(self.poller.poll()) + if self.router_socket in socks: + remote_address, message = self.router_socket.recv_multipart() + data = msgpack.loads(message) + logger.debug("Received message from %s, data:%s", + remote_address.decode(), data) + if data["cmd"] == "NEW": + unique_id = self.nccl.unique_id_from_bytes( + bytes(data["unique_id"])) + with torch.cuda.device(self.device): + rank = 1 + comm: flagcxComm_t = self.nccl.flagcxCommInitRank( + 2, unique_id, rank) + self.comms[remote_address.decode()] = (comm, rank) + logger.info( + "🤝flagcxCommInitRank Success, %s👈%s, MyRank:%s", + self.zmq_address, remote_address.decode(), rank) + elif data["cmd"] == "PUT": + try: + tensor = torch.empty(data["shape"], + dtype=getattr( + torch, data["dtype"]), + device=self.device) + + tensor_size = tensor.element_size() * tensor.numel() + if (self.buffer_size + tensor_size + > self.buffer_size_threshold): + self.router_socket.send_multipart( + [remote_address, b"2"]) + logger.warning( + "🔴[PUT]Recv Tensor, Out Of Threshold, " + "%s👈%s, data:%s", self.zmq_address, + remote_address.decode(), data) + tensor = None + else: + self.buffer_size += tensor_size + self.router_socket.send_multipart( + [remote_address, b"0"]) + comm, rank = self.comms[remote_address.decode()] + self._recv(comm, tensor, rank ^ 1) + logger.info( + "🔵[PUT]Recv Tensor, %s👈%s, MyRank:%s, data:%s, " + "shape:%s", self.zmq_address, + remote_address.decode(), rank, data, tensor.shape) + + tensor_id = data["tensor_id"] + with self.recv_store_cv: + self.recv_store[tensor_id] = tensor + self.recv_store_cv.notify() + + except 
torch.cuda.OutOfMemoryError: + self.router_socket.send_multipart( + [remote_address, b"1"]) + logger.warning( + "🔴[PUT]Recv Tensor, Out Of Memory, %s👈%s, " + "data:%s", self.zmq_address, + remote_address.decode(), data) + + elif data["cmd"] == "GET": + tensor_id = data["tensor_id"] + with self.send_store_cv: + tensor = self.send_store.pop(tensor_id, None) + if tensor is not None: + data = { + "ret": 0, + "shape": tensor.shape, + "dtype": + str(tensor.dtype).replace("torch.", "") + } + # LRU + self.send_store[tensor_id] = tensor + else: + data = {"ret": 1} + + self.router_socket.send_multipart( + [remote_address, msgpack.dumps(data)]) + + if data["ret"] == 0: + self._send(comm, tensor.to(self.device), rank ^ 1) + + logger.info( + "🔵[GET]Send Tensor, %s👉%s, " + "MyRank:%s, data:%s", self.zmq_address, + remote_address.decode(), rank, data) + else: + logger.warning( + "🚧Unexpected, Received message from %s, data:%s", + remote_address, data) + + # Asynchronous sending may cause conflicts between P2P NCCL and + # NCCL used in TP/PP, which can lead to deadlock issues. + def _send_async(self): + while True: + with self.send_queue_cv: + while not self.send_queue: + self.send_queue_cv.wait() + tensor_id, remote_address, tensor = self.send_queue.popleft() + self._send_sync(tensor_id, tensor, remote_address) + + def _send_sync( + self, + tensor_id: str, + tensor: torch.Tensor, + remote_address: typing.Optional[str] = None, + ) -> bool: + if remote_address is None: + return False + if remote_address not in self.socks: + self._create_connect(remote_address) + + sock = self.socks[remote_address] + comm, rank = self.comms[remote_address] + data = { + "cmd": "PUT", + "tensor_id": tensor_id, + "shape": tensor.shape, + "dtype": str(tensor.dtype).replace("torch.", "") + } + sock.send(msgpack.dumps(data)) + + response = sock.recv() + if response != b"0": + # with self.send_queue_cv: + # self.send_queue.append([tensor_id, remote_address, tensor]) + # self.send_queue_cv.notify() + logger.warning( + "🔴Send Tensor, Peer Out Of Memory/Threshold, %s 👉 %s, " + "MyRank:%s, data:%s, tensor:%s, size:%fGB, response:%s", + self.zmq_address, remote_address, rank, data, tensor.shape, + tensor.element_size() * tensor.numel() / 1024**3, + response.decode()) + return False + + self._send(comm, tensor.to(self.device), rank ^ 1) + logger.info("🔵Send Tensor, %s👉%s, MyRank:%s, data:%s, tensor:%s", + self.zmq_address, remote_address, rank, data, tensor.shape) + return True + + def _ping(self): + sock = self.context.socket(zmq.DEALER) + sock.setsockopt_string(zmq.IDENTITY, self.zmq_address) + logger.debug("ping start, zmq_address:%s", self.zmq_address) + sock.connect(f"tcp://{self.proxy_address}") + data = { + "type": "P" if self.config.is_kv_producer else "D", + "http_address": self.http_address, + "zmq_address": self.zmq_address + } + while True: + sock.send(msgpack.dumps(data)) + time.sleep(3) + + def _send(self, comm, tensor: torch.Tensor, dst: int, stream=None): + assert tensor.device == self.device, ( + f"this flagcx communicator is created to work on {self.device}, " + f"but the input tensor is on {tensor.device}") + if stream is None: + stream = current_stream() + + with self.comm_cv: + self.nccl.flagcxSend(buffer_type(tensor.data_ptr()), tensor.numel(), + flagcxDataTypeEnum.from_torch(tensor.dtype), dst, + comm, cudaStream_t(stream.cuda_stream)) + + def _recv(self, comm, tensor: torch.Tensor, src: int, stream=None): + assert tensor.device == self.device, ( + f"this flagcx communicator is created to work on {self.device}, " 
+            f"but the input tensor is on {tensor.device}")
+        if stream is None:
+            stream = current_stream()
+
+        with self.comm_cv:
+            self.nccl.flagcxRecv(buffer_type(tensor.data_ptr()), tensor.numel(),
+                                 flagcxDataTypeEnum.from_torch(tensor.dtype), src,
+                                 comm, cudaStream_t(stream.cuda_stream))
+
+    def close(self) -> None:
+        self._listener_thread.join()
+        if self.send_type == "PUT_ASYNC":
+            self._send_thread.join()
+        if self._ping_thread is not None:
+            self._ping_thread.join()
+
diff --git a/flagscale/backends/vllm/vllm/distributed/kv_transfer/kv_pipe/p2p_nccl_pipe.py b/flagscale/backends/vllm/vllm/distributed/kv_transfer/kv_pipe/p2p_nccl_pipe.py
new file mode 100644
index 000000000..c99dfbad2
--- /dev/null
+++ b/flagscale/backends/vllm/vllm/distributed/kv_transfer/kv_pipe/p2p_nccl_pipe.py
@@ -0,0 +1,472 @@
+# SPDX-License-Identifier: Apache-2.0
+
+import logging
+import os
+import threading
+import time
+import typing
+from collections import deque
+from typing import Any, Deque, Dict, List, Optional
+
+import msgpack
+import torch
+import zmq
+import ctypes
+
+from vllm.config import KVTransferConfig
+# from vllm.distributed.device_communicators.pynccl_wrapper import (
+#     NCCLLibrary, buffer_type, cudaStream_t, ncclComm_t, ncclDataTypeEnum)
+from vllm.distributed.device_communicators.flagcx_wrapper import (
+    FLAGCXLibrary,
+    buffer_type,
+    cudaStream_t,
+    flagcxComm_t,
+    flagcxDataTypeEnum,
+)
+from vllm.utils import current_stream, get_ip
+
+logger = logging.getLogger(__name__)
+
+
+class P2pNcclPipe:
+
+    def __init__(self,
+                 local_rank: int,
+                 config: KVTransferConfig,
+                 hostname: str = "",
+                 port_offset: int = 0,
+                 library_path: Optional[str] = None) -> None:
+        self.config = config
+        self.rank = port_offset
+        self.local_rank = local_rank
+        self.device = torch.device(f"cuda:{self.local_rank}")
+        # self.nccl = NCCLLibrary(library_path)
+        if library_path is None:
+            flagcx_path = os.getenv("FLAGCX_PATH")
+            if not flagcx_path:
+                raise ValueError(
+                    "Either pass library_path or set the FLAGCX_PATH "
+                    "environment variable to locate build/lib/libflagcx.so")
+            library_path = os.path.join(flagcx_path, "build/lib/libflagcx.so")
+        self.flagcx = FLAGCXLibrary(library_path)
+
+        if not hostname:
+            hostname = get_ip()
+        port = self.config.kv_port + port_offset
+        if port == 0:
+            raise ValueError("Port cannot be 0")
+        self._hostname = hostname
+        self._port = port
+
+        # Each card corresponds to a ZMQ address.
+        self.zmq_address = f"{self._hostname}:{self._port}"
+
+        # The `http_port` must be consistent with the port of the
+        # OpenAI-compatible API server.
+        self.http_address = (
+            f"{self._hostname}:"
+            f"{self.config.kv_connector_extra_config['http_port']}")
+
+        # If `proxy_ip` or `proxy_port` is `""`,
+        # then the ping thread will not be enabled.
+        proxy_ip = self.config.get_from_extra_config("proxy_ip", "")
+        proxy_port = self.config.get_from_extra_config("proxy_port", "")
+        if proxy_ip == "" or proxy_port == "":
+            self.proxy_address = ""
+        else:
+            self.proxy_address = proxy_ip + ":" + proxy_port
+
+        self.context = zmq.Context()
+        self.router_socket = self.context.socket(zmq.ROUTER)
+        self.router_socket.bind(f"tcp://{self.zmq_address}")
+
+        self.poller = zmq.Poller()
+        self.poller.register(self.router_socket, zmq.POLLIN)
+
+        self.send_store_cv = threading.Condition()
+        self.send_queue_cv = threading.Condition()
+        self.recv_store_cv = threading.Condition()
+        self.comm_cv = threading.Condition()
+
+        # The sending type includes three mutually exclusive options:
+        # PUT, GET, PUT_ASYNC.
+ self.send_type = self.config.get_from_extra_config("send_type", "PUT") + if self.send_type == "GET": + self.send_store: Dict[str, + torch.Tensor] = {} # tensor_id: torch.Tensor + else: + # PUT or PUT_ASYNC + self.send_queue: Deque[ + List[Any]] = deque() # tensor_id: torch.Tensor + if self.send_type == "PUT_ASYNC": + self._send_thread = threading.Thread(target=self._send_async, + daemon=True) + self._send_thread.start() + + self.recv_store: Dict[str, + torch.Tensor] = {} # tensor_id: torch.Tensor + self.socks: Dict[str, Any] = {} # remote_address: client socket + self.comms: Dict[str, Any] = {} # remote_address: (ncclComm_t, rank) + + self.buffer_size = 0 + self.buffer_size_threshold = self.config.kv_buffer_size + + self._listener_thread = threading.Thread( + target=self._listen_for_requests, daemon=True) + self._listener_thread.start() + + self._ping_thread = None + if port_offset == 0 and self.proxy_address != "": + self._ping_thread = threading.Thread(target=self._ping, + daemon=True) + self._ping_thread.start() + + def _create_connect(self, remote_address: typing.Optional[str] = None): + assert remote_address is not None + if remote_address not in self.socks: + sock = self.context.socket(zmq.DEALER) + sock.setsockopt_string(zmq.IDENTITY, self.zmq_address) + sock.connect(f"tcp://{remote_address}") + self.socks[remote_address] = sock + if remote_address in self.comms: + logger.info("👋comm exists, remote_address:%s, comms:%s", + remote_address, self.comms) + return sock, self.comms[remote_address] + + # unique_id = self.nccl.ncclGetUniqueId() + unique_id = self.flagcx.flagcxGetUniqueId().contents + # uid = unique_id.contents + data = {"cmd": "NEW", "unique_id": bytes(unique_id.internal)} + sock.send(msgpack.dumps(data)) + + with torch.cuda.device(self.device): + rank = 0 + # comm: ncclComm_t = self.nccl.ncclCommInitRank( + # 2, unique_id, rank) + comm = self.flagcx.flagcxCommInitRank( + 2, ctypes.byref(unique_id), rank) + self.comms[remote_address] = (comm, rank) + logger.info("🤝ncclCommInitRank Success, %s👉%s, MyRank: %s", + self.zmq_address, remote_address, rank) + + return self.socks[remote_address], self.comms[remote_address] + + def send_tensor( + self, + tensor_id: str, + tensor: torch.Tensor, + remote_address: typing.Optional[str] = None, + ) -> bool: + if remote_address is None: + with self.recv_store_cv: + self.recv_store[tensor_id] = tensor + self.recv_store_cv.notify() + return True + else: + if self.send_type == "PUT": + return self._send_sync(tensor_id, tensor, remote_address) + elif self.send_type == "PUT_ASYNC": + with self.send_queue_cv: + self.send_queue.append([tensor_id, remote_address, tensor]) + self.send_queue_cv.notify() + else: # GET + with self.send_store_cv: + tensor_size = tensor.element_size() * tensor.numel() + while (self.buffer_size + tensor_size + > self.buffer_size_threshold): + oldest_tenser_id = next(iter(self.send_store)) + oldest_tenser = self.send_store.pop(oldest_tenser_id) + oldest_tenser_size = oldest_tenser.element_size( + ) * oldest_tenser.numel() + self.buffer_size -= oldest_tenser_size + logger.info( + "⛔[GET]Send to %s, tensor_id:%s, tensor_size:%d," + " buffer_size:%d, oldest_tenser_size:%d, rank:%d", + remote_address, tensor_id, tensor_size, + self.buffer_size, oldest_tenser_size, self.rank) + + self.send_store[tensor_id] = tensor + self.buffer_size += tensor_size + logger.info( + "🔵[GET]Send to %s, tensor_id:%s, tensor_size:%d, " + "shape:%s, rank:%d, buffer_size:%d(%.2f%%)", + remote_address, tensor_id, tensor_size, tensor.shape, + 
self.rank, self.buffer_size, + self.buffer_size / self.buffer_size_threshold * 100) + + return True + + def recv_tensor( + self, + tensor_id: str, + remote_address: typing.Optional[str] = None, + ) -> torch.Tensor: + if self.send_type == "PUT" or self.send_type == "PUT_ASYNC": + start_time = time.time() + with self.recv_store_cv: + while tensor_id not in self.recv_store: + self.recv_store_cv.wait() + tensor = self.recv_store[tensor_id] + self.recv_store[tensor_id] = None + while len(self.recv_store) > 10000: + self.recv_store.pop(next(iter(self.recv_store))) + + duration = time.time() - start_time + if tensor is not None: + self.buffer_size -= (tensor.element_size() * tensor.numel()) + logger.info( + "🔵[PUT]Recv From %s, tensor_id:%s, shape:%s, " + "duration:%.3fms, size:%.3fGB, rank:%d", remote_address, + tensor_id, tensor.shape, duration * 1000, + tensor.element_size() * tensor.numel() / 1024**3, + self.rank) + else: + logger.warning( + "🔴[PUT]Recv From %s, tensor_id:%s, duration:%.3fms, " + "rank:%d", remote_address, tensor_id, duration * 1000, + self.rank) + return tensor + + # GET + if remote_address is None: + return None + + if remote_address not in self.socks: + self._create_connect(remote_address) + + sock = self.socks[remote_address] + comm, rank = self.comms[remote_address] + + data = {"cmd": "GET", "tensor_id": tensor_id} + sock.send(msgpack.dumps(data)) + + message = sock.recv() + data = msgpack.loads(message) + if data["ret"] != 0: + logger.warning("🔴[GET]Recv From %s, tensor_id: %s, ret: %d", + remote_address, tensor_id, data["ret"]) + return None + + tensor = torch.empty(data["shape"], + dtype=getattr(torch, data["dtype"]), + device=self.device) + + start_time = time.time() + self._recv(comm, tensor, rank ^ 1) + duration = time.time() - start_time + logger.info( + "🔵[GET]Recv From %s, tensor_id:%s, shape:%s, duration:%.3fms, " + "size:%.3fGB, rank:%d", remote_address, tensor_id, tensor.shape, + duration * 1000, + tensor.element_size() * tensor.numel() / 1024**3, self.rank) + + return tensor + + def _listen_for_requests(self): + while True: + socks = dict(self.poller.poll()) + if self.router_socket in socks: + remote_address, message = self.router_socket.recv_multipart() + data = msgpack.loads(message) + logger.debug("Received message from %s, data:%s", + remote_address.decode(), data) + if data["cmd"] == "NEW": + unique_id = self.flagcx.unique_id_from_bytes( + bytes(data["unique_id"])) + with torch.cuda.device(self.device): + rank = 1 + # comm: ncclComm_t = self.nccl.ncclCommInitRank( + # 2, unique_id, rank) + comm = self.flagcx.flagcxCommInitRank( + 2, ctypes.byref(unique_id), rank) + self.comms[remote_address.decode()] = (comm, rank) + logger.info( + "🤝ncclCommInitRank Success, %s👈%s, MyRank:%s", + self.zmq_address, remote_address.decode(), rank) + elif data["cmd"] == "PUT": + tensor_id = data["tensor_id"] + try: + tensor = torch.empty(data["shape"], + dtype=getattr( + torch, data["dtype"]), + device=self.device) + + tensor_size = tensor.element_size() * tensor.numel() + if (self.buffer_size + tensor_size + > self.buffer_size_threshold): + self.router_socket.send_multipart( + [remote_address, b"2"]) + logger.warning( + "🔴[PUT]Recv Tensor, Out Of Threshold, " + "%s👈%s, data:%s", self.zmq_address, + remote_address.decode(), data) + tensor = None + else: + self.buffer_size += tensor_size + self.router_socket.send_multipart( + [remote_address, b"0"]) + comm, rank = self.comms[remote_address.decode()] + self._recv(comm, tensor, rank ^ 1) + logger.info( + "🔵[PUT]Recv Tensor, 
%s👈%s, MyRank:%s, " + "data:%s, shape:%s", self.zmq_address, + remote_address.decode(), rank, data, + tensor.shape) + + except torch.cuda.OutOfMemoryError: + self.router_socket.send_multipart( + [remote_address, b"1"]) + tensor = None + logger.warning( + "🔴[PUT]Recv Tensor, Out Of Memory, %s👈%s, " + "data:%s", self.zmq_address, + remote_address.decode(), data) + + with self.recv_store_cv: + self.recv_store[tensor_id] = tensor + self.recv_store_cv.notify() + + elif data["cmd"] == "GET": + tensor_id = data["tensor_id"] + with self.send_store_cv: + tensor = self.send_store.pop(tensor_id, None) + if tensor is not None: + data = { + "ret": 0, + "shape": tensor.shape, + "dtype": + str(tensor.dtype).replace("torch.", "") + } + # LRU + self.send_store[tensor_id] = tensor + else: + data = {"ret": 1} + + self.router_socket.send_multipart( + [remote_address, msgpack.dumps(data)]) + + if data["ret"] == 0: + self._send(comm, tensor.to(self.device), rank ^ 1) + + logger.info( + "🔵[GET]Send Tensor, %s👉%s, " + "MyRank:%s, data:%s", self.zmq_address, + remote_address.decode(), rank, data) + else: + logger.warning( + "🚧Unexpected, Received message from %s, data:%s", + remote_address, data) + + # Asynchronous sending may cause conflicts between P2P NCCL and + # NCCL used in TP/PP, which can lead to deadlock issues. + def _send_async(self): + while True: + with self.send_queue_cv: + while not self.send_queue: + self.send_queue_cv.wait() + tensor_id, remote_address, tensor = self.send_queue.popleft() + if not self.send_queue: + self.send_queue_cv.notify() + self._send_sync(tensor_id, tensor, remote_address) + + def wait_for_sent(self): + if self.send_type == "PUT_ASYNC": + start_time = time.time() + with self.send_queue_cv: + while self.send_queue: + self.send_queue_cv.wait() + duration = time.time() - start_time + logger.info( + "🚧[PUT_ASYNC]It took %.3fms to wait for the send_queue" + " to be empty, rank:%d", duration * 1000, self.rank) + + def _send_sync( + self, + tensor_id: str, + tensor: torch.Tensor, + remote_address: typing.Optional[str] = None, + ) -> bool: + if remote_address is None: + return False + if remote_address not in self.socks: + self._create_connect(remote_address) + + sock = self.socks[remote_address] + comm, rank = self.comms[remote_address] + data = { + "cmd": "PUT", + "tensor_id": tensor_id, + "shape": tensor.shape, + "dtype": str(tensor.dtype).replace("torch.", "") + } + sock.send(msgpack.dumps(data)) + + response = sock.recv() + if response != b"0": + # with self.send_queue_cv: + # self.send_queue.append([tensor_id, remote_address, tensor]) + # self.send_queue_cv.notify() + logger.warning( + "🔴Send Tensor, Peer Out Of Memory/Threshold, %s 👉 %s, " + "MyRank:%s, data:%s, tensor:%s, size:%fGB, response:%s", + self.zmq_address, remote_address, rank, data, tensor.shape, + tensor.element_size() * tensor.numel() / 1024**3, + response.decode()) + return False + + self._send(comm, tensor.to(self.device), rank ^ 1) + logger.info("🔵Send Tensor, %s👉%s, MyRank:%s, data:%s, tensor:%s", + self.zmq_address, remote_address, rank, data, tensor.shape) + return True + + def _ping(self): + sock = self.context.socket(zmq.DEALER) + sock.setsockopt_string(zmq.IDENTITY, self.zmq_address) + logger.debug("ping start, zmq_address:%s", self.zmq_address) + sock.connect(f"tcp://{self.proxy_address}") + data = { + "type": "P" if self.config.is_kv_producer else "D", + "http_address": self.http_address, + "zmq_address": self.zmq_address + } + while True: + sock.send(msgpack.dumps(data)) + time.sleep(3) + + def 
_send(self, comm, tensor: torch.Tensor, dst: int, stream=None): + assert tensor.device == self.device, ( + f"this nccl communicator is created to work on {self.device}, " + f"but the input tensor is on {tensor.device}") + if stream is None: + stream = current_stream() + + with self.comm_cv: + # self.nccl.ncclSend(buffer_type(tensor.data_ptr()), tensor.numel(), + # ncclDataTypeEnum.from_torch(tensor.dtype), dst, + # comm, cudaStream_t(stream.cuda_stream)) + # self.nccl.flagcxSend(buffer_type(tensor.data_ptr()), tensor.numel(), + # flagcxDataTypeEnum.from_torch(tensor.dtype), dst, + # comm, cudaStream_t(stream.cuda_stream)) + flagcx_stream = self.flagcx.adaptor_stream_copy(stream) + self.flagcx.flagcxSend(buffer_type(tensor.data_ptr()), tensor.numel(), + flagcxDataTypeEnum.from_torch(tensor.dtype), dst, + comm, flagcx_stream) + self.flagcx.adaptor_stream_free(flagcx_stream) + + def _recv(self, comm, tensor: torch.Tensor, src: int, stream=None): + assert tensor.device == self.device, ( + f"this nccl communicator is created to work on {self.device}, " + f"but the input tensor is on {tensor.device}") + if stream is None: + stream = current_stream() + + with self.comm_cv: + # self.nccl.ncclRecv(buffer_type(tensor.data_ptr()), tensor.numel(), + # ncclDataTypeEnum.from_torch(tensor.dtype), src, + # comm, cudaStream_t(stream.cuda_stream)) + # self.flagcx.flagcxRecv(buffer_type(tensor.data_ptr()), tensor.numel(), + # flagcxDataTypeEnum.from_torch(tensor.dtype), src, + # comm, cudaStream_t(stream.cuda_stream)) + flagcx_stream = self.flagcx.adaptor_stream_copy(stream) + self.flagcx.flagcxRecv(buffer_type(tensor.data_ptr()), tensor.numel(), + flagcxDataTypeEnum.from_torch(tensor.dtype), src, + comm, flagcx_stream) + self.flagcx.adaptor_stream_free(flagcx_stream) + + def close(self) -> None: + self._listener_thread.join() + if self.send_type == "PUT_ASYNC": + self._send_thread.join() + if self._ping_thread is not None: + self._ping_thread.join() \ No newline at end of file diff --git a/third_party/vllm b/third_party/vllm index dc1b4a6f1..139e70670 160000 --- a/third_party/vllm +++ b/third_party/vllm @@ -1 +1 @@ -Subproject commit dc1b4a6f1300003ae27f033afbdff5e2683721ce +Subproject commit 139e706706402859fd55381a39dd7182f0537977 From f497249d92eed8f4bacc69deb1d1fa0c3a777e4f Mon Sep 17 00:00:00 2001 From: chenzhuo Date: Tue, 29 Apr 2025 10:30:11 +0800 Subject: [PATCH 49/62] polish code --- .../config_qwen2.5_7b_pd_disaggregation.yaml | 6 +-- examples/qwen/conf/hostfile.txt | 4 +- flagscale/runner/runner_serve.py | 29 ++++---------- flagscale/serve/run_pd_disagg_router.py | 40 ++++++++----------- 4 files changed, 29 insertions(+), 50 deletions(-) diff --git a/examples/qwen/conf/config_qwen2.5_7b_pd_disaggregation.yaml b/examples/qwen/conf/config_qwen2.5_7b_pd_disaggregation.yaml index 2f0e42fdf..aefe49c83 100644 --- a/examples/qwen/conf/config_qwen2.5_7b_pd_disaggregation.yaml +++ b/examples/qwen/conf/config_qwen2.5_7b_pd_disaggregation.yaml @@ -11,10 +11,10 @@ experiment: port: 10001 use_fs_serve: false prefill_decode_disaggregation: true - prefill_num: 1 - prefill_address: 10.1.1.122 # optional, default "auto" + prefill_num: 2 + prefill_address: x.x.x.x # optional, default "auto" decode_num: 2 - decode_address: 10.1.1.108 # optional, default "auto" + decode_address: x.x.x.x # optional, default "auto" runner: hostfile: examples/qwen/conf/hostfile.txt docker: fr-v2 diff --git a/examples/qwen/conf/hostfile.txt b/examples/qwen/conf/hostfile.txt index e49bb3049..0d8b1e05f 100644 --- 
a/examples/qwen/conf/hostfile.txt +++ b/examples/qwen/conf/hostfile.txt @@ -1,5 +1,5 @@ # ip slots type=xxx[optional] # master node -10.1.1.122 slots=8 type=gpu +x.x.x.x slots=8 type=gpu # worker nodes -10.1.1.108 slots=8 type=gpu +x.x.x.x slots=8 type=gpu diff --git a/flagscale/runner/runner_serve.py b/flagscale/runner/runner_serve.py index cf2eba568..6dbd0fde9 100644 --- a/flagscale/runner/runner_serve.py +++ b/flagscale/runner/runner_serve.py @@ -294,7 +294,6 @@ def _generate_run_script_serve(config, host, node_rank, cmd, background=True, wi vllm_path = f"{root_dir}/vllm" deploy_config = config.experiment.get("deploy", {}) envs = config.experiment.get("envs", {}) - print(f"shell file ======================== {host_run_script_file}", flush=True) with open(host_run_script_file, "w") as f: f.write("#!/bin/bash\n\n") f.write("set -x\n") @@ -321,9 +320,7 @@ def _generate_run_script_serve(config, host, node_rank, cmd, background=True, wi kv_related_ports = _get_multiple_free_ports(ports_num) pd_proxy_port = deploy_config.get("pd_proxy_port", None) if not pd_proxy_port: - raise ValueError( - f"PD disaggregation requires a proxy port to be set." - ) + raise ValueError(f"PD disaggregation requires a proxy port to be set.") engine_args = _get_engine_args(config) command_items = ["vllm", "serve"] @@ -331,7 +328,6 @@ def _generate_run_script_serve(config, host, node_rank, cmd, background=True, wi other_args = flatten_dict_to_args(engine_args, ["model", "port"]) command_items.extend(other_args) vllm_command = " ".join(command_items) - # vllm_command = "nohup " + vllm_command if before_start_cmd: vllm_command = f"{before_start_cmd} && " + vllm_command if envs_str: @@ -386,21 +382,18 @@ def _generate_run_script_serve(config, host, node_rank, cmd, background=True, wi "http_port": str(http_port), }, } - print( + logger.info( f"============= prefill instance {i}, p_kv_config: {p_kv_config} =============", flush=True, ) card_ids = resource_manager.get_available_card_ids( - address=p_address, - num=each_instance_card_num, + address=p_address, num=each_instance_card_num ) card_ids_str = ",".join(map(str, card_ids)) ids_env = f"export CUDA_VISIBLE_DEVICES={card_ids_str}" p_kv_config_json = json.dumps(p_kv_config) - p_instance_log_path = os.path.join( - default_log_dir, f"prefill_{i}.log" - ) + p_instance_log_path = os.path.join(default_log_dir, f"prefill_{i}.log") if p_address != master_ip: p_kv_config_formate_json = p_kv_config_json.replace('"', '\\"') @@ -433,21 +426,18 @@ def _generate_run_script_serve(config, host, node_rank, cmd, background=True, wi "http_port": str(http_port), }, } - print( + logger.info( f"============= decode instance {i}, d_kv_config: {d_kv_config} =============", flush=True, ) card_ids = resource_manager.get_available_card_ids( - address=d_address, - num=each_instance_card_num, + address=d_address, num=each_instance_card_num ) card_ids_str = ",".join(map(str, card_ids)) ids_env = f"export CUDA_VISIBLE_DEVICES={card_ids_str}" d_kv_config_json = json.dumps(d_kv_config) - d_instance_log_path = os.path.join( - default_log_dir, f"decode_{j}.log" - ) + d_instance_log_path = os.path.join(default_log_dir, f"decode_{j}.log") if d_address != master_ip: d_kv_config_formate_json = d_kv_config_json.replace('"', '\\"') @@ -683,9 +673,7 @@ def _prepare(self): self.user_envs = self.config.experiment.get("envs", {}) entrypoint = self.config.experiment.task.get("entrypoint", None) if self.inference_engine: # pd_disagg_router - if self.config.experiment.get("deploy", {}).get( - 
"prefill_decode_disaggregation", False - ): + if self.config.experiment.get("deploy", {}).get("prefill_decode_disaggregation", False): self.user_script = "flagscale/serve/run_pd_disagg_router.py" elif not self.use_fs_serve: self.user_script = "flagscale/serve/run_inference_engine.py" @@ -783,7 +771,6 @@ def _stop_each(self, host, node_rank): kill_process_tree(pid) ray_executable = shutil.which("ray") - print(ray_executable) if ray_executable: ray_path = os.path.realpath(ray_executable) os.system(f"{ray_path} stop") diff --git a/flagscale/serve/run_pd_disagg_router.py b/flagscale/serve/run_pd_disagg_router.py index 8e45a443b..3dd35d911 100644 --- a/flagscale/serve/run_pd_disagg_router.py +++ b/flagscale/serve/run_pd_disagg_router.py @@ -1,4 +1,11 @@ -import logging +# Copyright (c) 2025, BAAI. All rights reserved. +# +# Adopted from https://github.com/vllm-project/vllm/blob/1ad957950ffc1552af5abda78c03d88ddb67945b/examples/online_serving/disagg_xpyd/disagg_prefill_proxy_xpyd.py. Below is the original copyright: +# +# SPDX-License-Identifier: Apache-2.0 +# + + import os import random import socket @@ -8,6 +15,7 @@ import aiohttp import msgpack import zmq + from quart import Quart, make_response, request try: @@ -29,18 +37,13 @@ class LoadManager: def __init__(self): self._lock = threading.Lock() # Each resource type 'P' or 'D' maps to {http_addr: {'zmq': zmq_addr, 'load': int}} - self._instances: dict[str, dict[str, dict[str, object]]] = { - "P": {}, - "D": {}, - } + self._instances: dict[str, dict[str, dict[str, object]]] = {"P": {}, "D": {}} def register(self, rtype: str, http_addr: str, zmq_addr: str): with self._lock: if http_addr not in self._instances[rtype]: self._instances[rtype][http_addr] = {"zmq": zmq_addr, "load": 0} - logger.info( - f"Registered new {rtype}-instance {http_addr} (zmq={zmq_addr})" - ) + logger.info(f"Registered new {rtype}-instance {http_addr} (zmq={zmq_addr})") else: # If zmq address changed, synchronize it self._instances[rtype][http_addr]["zmq"] = zmq_addr @@ -67,13 +70,8 @@ def get_random(self, rtype: str) -> tuple[str, str]: def get_robin_loaded(self, rtype: str) -> tuple[str, str]: with self._lock: - http_addr, info = min( - self._instances[rtype].items(), key=lambda kv: kv[1]["load"] - ) - print( - f"========== whole instance status {self._instances}==========", - flush=True, - ) + http_addr, info = min(self._instances[rtype].items(), key=lambda kv: kv[1]["load"]) + print(f"========== whole instance status {self._instances}==========", flush=True) return http_addr, info["zmq"] @@ -168,9 +166,7 @@ async def forward_request(url, data, request_id): async def handle_request(): try: original_data = await request.get_json() - endpoint = ( - request.path - ) # this will be '/v1/completions' or '/v1/chat/completions' + endpoint = request.path # this will be '/v1/completions' or '/v1/chat/completions' # Prefill request: max_tokens=1 prefill_request = original_data.copy() @@ -191,9 +187,7 @@ async def handle_request(): logger.info(f"Selected D-instance {decode_addr} via '{SCHEDULING_STRATEGY}'") # Keep original request_id composition format - request_id = ( - f"___prefill_addr_{prefill_zmq}___decode_addr_{decode_zmq}_{random_uuid()}" - ) + request_id = f"___prefill_addr_{prefill_zmq}___decode_addr_{decode_zmq}_{random_uuid()}" # Execute Prefill and update load lm.increment_load("P", prefill_addr) @@ -235,9 +229,7 @@ def main(): raise ValueError("No port specified in deploy config") if not pd_proxy_port: raise ValueError("No pd_proxy_port specified in deploy config") 
- print( - f"Starting Proxy Server...with pd_proxy_port {pd_proxy_port} and serve_port {serve_port}" - ) + print(f"Starting Proxy Server...with pd_proxy_port {pd_proxy_port} and serve_port {serve_port}") listener = start_service_discovery("0.0.0.0", pd_proxy_port) app.run(host="0.0.0.0", port=serve_port) listener.join() From 89cbc650467315dbf3568be33e233487d0182a35 Mon Sep 17 00:00:00 2001 From: chenzhuo Date: Tue, 29 Apr 2025 11:17:09 +0800 Subject: [PATCH 50/62] polish copyright --- .../device_communicators/flagcx_wrapper.py | 4 +- .../device_communicators/pynccl_wrapper.py | 3 + .../kv_transfer/kv_connector/factory.py | 2 + .../kv_transfer/kv_connector/p2p_connector.py | 9 +- .../kv_pipe/flagcx_p2p_nccl_pipe.py | 99 ++++++++++++------- .../kv_transfer/kv_pipe/p2p_nccl_pipe.py | 85 ++++++---------- 6 files changed, 104 insertions(+), 98 deletions(-) diff --git a/flagscale/backends/vllm/vllm/distributed/device_communicators/flagcx_wrapper.py b/flagscale/backends/vllm/vllm/distributed/device_communicators/flagcx_wrapper.py index 175ebe805..634db5852 100644 --- a/flagscale/backends/vllm/vllm/distributed/device_communicators/flagcx_wrapper.py +++ b/flagscale/backends/vllm/vllm/distributed/device_communicators/flagcx_wrapper.py @@ -1,6 +1,6 @@ +# Copyright (c) 2025, BAAI. All rights reserved. +# Adopted from https://github.com/vllm-project/vllm/blob/main/vllm/distributed/device_communicators/pynccl_wrapper.py. Below is the original copyright: # SPDX-License-Identifier: Apache-2.0 -# reference https://github.com/vllm-project/vllm/blob/main/vllm/distributed/device_communicators/pynccl_wrapper.py - import ctypes import platform from dataclasses import dataclass diff --git a/flagscale/backends/vllm/vllm/distributed/device_communicators/pynccl_wrapper.py b/flagscale/backends/vllm/vllm/distributed/device_communicators/pynccl_wrapper.py index 83619c27f..7b490ca56 100644 --- a/flagscale/backends/vllm/vllm/distributed/device_communicators/pynccl_wrapper.py +++ b/flagscale/backends/vllm/vllm/distributed/device_communicators/pynccl_wrapper.py @@ -1,3 +1,6 @@ +# Copied from https://github.com/vllm-project/vllm/blob/1ad957950ffc1552af5abda78c03d88ddb67945b/vllm/distributed/device_communicators/pynccl_wrapper.py. +# Below is the original copyright: + # SPDX-License-Identifier: Apache-2.0 # This file is a pure Python wrapper for the NCCL library. diff --git a/flagscale/backends/vllm/vllm/distributed/kv_transfer/kv_connector/factory.py b/flagscale/backends/vllm/vllm/distributed/kv_transfer/kv_connector/factory.py index 39c9809ce..52de6757e 100644 --- a/flagscale/backends/vllm/vllm/distributed/kv_transfer/kv_connector/factory.py +++ b/flagscale/backends/vllm/vllm/distributed/kv_transfer/kv_connector/factory.py @@ -1,3 +1,5 @@ +# Copied from https://github.com/vllm-project/vllm/blob/1ad957950ffc1552af5abda78c03d88ddb67945b/vllm/distributed/kv_transfer/kv_connector/factory.py. 
+# Below is the original copyright: # SPDX-License-Identifier: Apache-2.0 import importlib diff --git a/flagscale/backends/vllm/vllm/distributed/kv_transfer/kv_connector/p2p_connector.py b/flagscale/backends/vllm/vllm/distributed/kv_transfer/kv_connector/p2p_connector.py index b0e2d96ba..56b2ba9c8 100644 --- a/flagscale/backends/vllm/vllm/distributed/kv_transfer/kv_connector/p2p_connector.py +++ b/flagscale/backends/vllm/vllm/distributed/kv_transfer/kv_connector/p2p_connector.py @@ -1,5 +1,7 @@ +# Mainly adopted from https://github.com/FlagOpen/FlagScale/blob/44ceca57dd6f86b10163968e617497c613e47d6e/flagscale/backends/vllm/vllm/distributed/kv_transfer/kv_connector/p2p_connector.py. +# Below is the original copyright: # SPDX-License-Identifier: Apache-2.0 - +import os import re from typing import TYPE_CHECKING, List, Tuple, Union @@ -9,7 +11,10 @@ from vllm import _custom_ops as ops from vllm.config import VllmConfig from vllm.distributed.kv_transfer.kv_connector.base import KVConnectorBase -from vllm.distributed.kv_transfer.kv_pipe.p2p_nccl_pipe import P2pNcclPipe +if os.getenv('USE)FLAGCX', 'False').lower() == 'true': + from vllm.distributed.kv_transfer.kv_pipe.flagcx_p2p_nccl_pipe import P2pNcclPipe +else: + from vllm.distributed.kv_transfer.kv_pipe.p2p_nccl_pipe import P2pNcclPipe from vllm.logger import init_logger from vllm.sequence import IntermediateTensors diff --git a/flagscale/backends/vllm/vllm/distributed/kv_transfer/kv_pipe/flagcx_p2p_nccl_pipe.py b/flagscale/backends/vllm/vllm/distributed/kv_transfer/kv_pipe/flagcx_p2p_nccl_pipe.py index e62a5f656..b407b54df 100644 --- a/flagscale/backends/vllm/vllm/distributed/kv_transfer/kv_pipe/flagcx_p2p_nccl_pipe.py +++ b/flagscale/backends/vllm/vllm/distributed/kv_transfer/kv_pipe/flagcx_p2p_nccl_pipe.py @@ -1,5 +1,7 @@ +# Mainly adopted from https://github.com/vllm-project/vllm/blob/1ad957950ffc1552af5abda78c03d88ddb67945b/vllm/distributed/kv_transfer/kv_pipe/p2p_nccl_pipe.py. 
+# Below is the original copyright: # SPDX-License-Identifier: Apache-2.0 - +import os import logging import threading import time @@ -10,15 +12,16 @@ import msgpack import torch import zmq -import os +import ctypes from vllm.config import KVTransferConfig - - from vllm.distributed.device_communicators.flagcx_wrapper import ( - FLAGCXLibrary, buffer_type, flagcxComm_t, flagcxDataTypeEnum, - flagcxRedOpTypeEnum, flagcxUniqueId, cudaStream_t) - + FLAGCXLibrary, + buffer_type, + cudaStream_t, + flagcxComm_t, + flagcxDataTypeEnum, +) from vllm.utils import current_stream, get_ip logger = logging.getLogger(__name__) @@ -38,8 +41,7 @@ def __init__(self, self.device = torch.device(f"cuda:{self.local_rank}") flagcx_path = os.getenv('FLAGCX_PATH') library_path=os.path.join(flagcx_path, "build/lib/libflagcx.so") - print("============== flagcx library_path ============", library_path, flush=True) - self.nccl = FLAGCXLibrary(library_path) + self.flagcx = FLAGCXLibrary(library_path) if not hostname: hostname = get_ip() @@ -96,7 +98,7 @@ def __init__(self, self.recv_store: Dict[str, torch.Tensor] = {} # tensor_id: torch.Tensor self.socks: Dict[str, Any] = {} # remote_address: client socket - self.comms: Dict[str, Any] = {} # remote_address: (flagcxComm_t, rank) + self.comms: Dict[str, Any] = {} # remote_address: (ncclComm_t, rank) self.buffer_size = 0 self.buffer_size_threshold = self.config.kv_buffer_size @@ -123,16 +125,16 @@ def _create_connect(self, remote_address: typing.Optional[str] = None): remote_address, self.comms) return sock, self.comms[remote_address] - unique_id = self.nccl.flagcxGetUniqueId() + unique_id = self.flagcx.flagcxGetUniqueId().contents data = {"cmd": "NEW", "unique_id": bytes(unique_id.internal)} sock.send(msgpack.dumps(data)) with torch.cuda.device(self.device): rank = 0 - comm: flagcxComm_t = self.nccl.flagcxCommInitRank( - 2, unique_id, rank) + comm = self.flagcx.flagcxCommInitRank( + 2, ctypes.byref(unique_id), rank) self.comms[remote_address] = (comm, rank) - logger.info("🤝flagcxCommInitRank Success, %s👉%s, MyRank: %s", + logger.info("🤝ncclCommInitRank Success, %s👉%s, MyRank: %s", self.zmq_address, remote_address, rank) return self.socks[remote_address], self.comms[remote_address] @@ -192,10 +194,11 @@ def recv_tensor( with self.recv_store_cv: while tensor_id not in self.recv_store: self.recv_store_cv.wait() - # TODO:Abatom, To avoid an overly large dictionary. 
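            # The replacement below keeps the consumed entry as a None
            # placeholder (so a repeated lookup of the same tensor_id returns
            # immediately instead of blocking) and bounds recv_store by
            # evicting the oldest keys once it holds more than 10000 entries;
            # next(iter(...)) yields the oldest key because Python dicts
            # preserve insertion order.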
- # tensor = self.recv_store.pop(tensor_id) tensor = self.recv_store[tensor_id] self.recv_store[tensor_id] = None + while len(self.recv_store) > 10000: + self.recv_store.pop(next(iter(self.recv_store))) + duration = time.time() - start_time if tensor is not None: self.buffer_size -= (tensor.element_size() * tensor.numel()) @@ -256,17 +259,20 @@ def _listen_for_requests(self): logger.debug("Received message from %s, data:%s", remote_address.decode(), data) if data["cmd"] == "NEW": - unique_id = self.nccl.unique_id_from_bytes( + unique_id = self.flagcx.unique_id_from_bytes( bytes(data["unique_id"])) with torch.cuda.device(self.device): rank = 1 - comm: flagcxComm_t = self.nccl.flagcxCommInitRank( - 2, unique_id, rank) + # comm: ncclComm_t = self.nccl.ncclCommInitRank( + # 2, unique_id, rank) + comm = self.flagcx.flagcxCommInitRank( + 2, ctypes.byref(unique_id), rank) self.comms[remote_address.decode()] = (comm, rank) logger.info( - "🤝flagcxCommInitRank Success, %s👈%s, MyRank:%s", + "🤝ncclCommInitRank Success, %s👈%s, MyRank:%s", self.zmq_address, remote_address.decode(), rank) elif data["cmd"] == "PUT": + tensor_id = data["tensor_id"] try: tensor = torch.empty(data["shape"], dtype=getattr( @@ -290,23 +296,24 @@ def _listen_for_requests(self): comm, rank = self.comms[remote_address.decode()] self._recv(comm, tensor, rank ^ 1) logger.info( - "🔵[PUT]Recv Tensor, %s👈%s, MyRank:%s, data:%s, " - "shape:%s", self.zmq_address, - remote_address.decode(), rank, data, tensor.shape) - - tensor_id = data["tensor_id"] - with self.recv_store_cv: - self.recv_store[tensor_id] = tensor - self.recv_store_cv.notify() + "🔵[PUT]Recv Tensor, %s👈%s, MyRank:%s, " + "data:%s, shape:%s", self.zmq_address, + remote_address.decode(), rank, data, + tensor.shape) except torch.cuda.OutOfMemoryError: self.router_socket.send_multipart( [remote_address, b"1"]) + tensor = None logger.warning( "🔴[PUT]Recv Tensor, Out Of Memory, %s👈%s, " "data:%s", self.zmq_address, remote_address.decode(), data) + with self.recv_store_cv: + self.recv_store[tensor_id] = tensor + self.recv_store_cv.notify() + elif data["cmd"] == "GET": tensor_id = data["tensor_id"] with self.send_store_cv: @@ -346,8 +353,21 @@ def _send_async(self): while not self.send_queue: self.send_queue_cv.wait() tensor_id, remote_address, tensor = self.send_queue.popleft() + if not self.send_queue: + self.send_queue_cv.notify() self._send_sync(tensor_id, tensor, remote_address) + def wait_for_sent(self): + if self.send_type == "PUT_ASYNC": + start_time = time.time() + with self.send_queue_cv: + while self.send_queue: + self.send_queue_cv.wait() + duration = time.time() - start_time + logger.info( + "🚧[PUT_ASYNC]It took %.3fms to wait for the send_queue" + " to be empty, rank:%d", duration * 1000, self.rank) + def _send_sync( self, tensor_id: str, @@ -403,32 +423,35 @@ def _ping(self): def _send(self, comm, tensor: torch.Tensor, dst: int, stream=None): assert tensor.device == self.device, ( - f"this flagcx communicator is created to work on {self.device}, " + f"this nccl communicator is created to work on {self.device}, " f"but the input tensor is on {tensor.device}") if stream is None: stream = current_stream() with self.comm_cv: - self.nccl.flagcxSend(buffer_type(tensor.data_ptr()), tensor.numel(), - flagcxDataTypeEnum.from_torch(tensor.dtype), dst, - comm, cudaStream_t(stream.cuda_stream)) + flagcx_stream = self.flagcx.adaptor_stream_copy(stream) + self.flagcx.flagcxSend(buffer_type(tensor.data_ptr()), tensor.numel(), + flagcxDataTypeEnum.from_torch(tensor.dtype), dst, + 
comm, flagcx_stream) + self.flagcx.adaptor_stream_free(flagcx_stream) def _recv(self, comm, tensor: torch.Tensor, src: int, stream=None): assert tensor.device == self.device, ( - f"this flagcx communicator is created to work on {self.device}, " + f"this nccl communicator is created to work on {self.device}, " f"but the input tensor is on {tensor.device}") if stream is None: stream = current_stream() with self.comm_cv: - self.nccl.flagcxRecv(buffer_type(tensor.data_ptr()), tensor.numel(), - flagcxDataTypeEnum.from_torch(tensor.dtype), src, - comm, cudaStream_t(stream.cuda_stream)) + flagcx_stream = self.flagcx.adaptor_stream_copy(stream) + self.flagcx.flagcxRecv(buffer_type(tensor.data_ptr()), tensor.numel(), + flagcxDataTypeEnum.from_torch(tensor.dtype), src, + comm, flagcx_stream) + self.flagcx.adaptor_stream_free(flagcx_stream) def close(self) -> None: self._listener_thread.join() if self.send_type == "PUT_ASYNC": self._send_thread.join() if self._ping_thread is not None: - self._ping_thread.join() - + self._ping_thread.join() \ No newline at end of file diff --git a/flagscale/backends/vllm/vllm/distributed/kv_transfer/kv_pipe/p2p_nccl_pipe.py b/flagscale/backends/vllm/vllm/distributed/kv_transfer/kv_pipe/p2p_nccl_pipe.py index c99dfbad2..2451cd317 100644 --- a/flagscale/backends/vllm/vllm/distributed/kv_transfer/kv_pipe/p2p_nccl_pipe.py +++ b/flagscale/backends/vllm/vllm/distributed/kv_transfer/kv_pipe/p2p_nccl_pipe.py @@ -1,5 +1,6 @@ +# Copied adopted from https://github.com/vllm-project/vllm/blob/1ad957950ffc1552af5abda78c03d88ddb67945b/vllm/distributed/kv_transfer/kv_pipe/p2p_nccl_pipe.py. +# Below is the original copyright: # SPDX-License-Identifier: Apache-2.0 - import logging import threading import time @@ -10,18 +11,10 @@ import msgpack import torch import zmq -import ctypes from vllm.config import KVTransferConfig -# from vllm.distributed.device_communicators.pynccl_wrapper import ( -# NCCLLibrary, buffer_type, cudaStream_t, ncclComm_t, ncclDataTypeEnum) -from vllm.distributed.device_communicators.flagcx_wrapper import ( - FLAGCXLibrary, - buffer_type, - cudaStream_t, - flagcxComm_t, - flagcxDataTypeEnum, -) +from vllm.distributed.device_communicators.pynccl_wrapper import ( + NCCLLibrary, buffer_type, cudaStream_t, ncclComm_t, ncclDataTypeEnum) from vllm.utils import current_stream, get_ip logger = logging.getLogger(__name__) @@ -39,8 +32,7 @@ def __init__(self, self.rank = port_offset self.local_rank = local_rank self.device = torch.device(f"cuda:{self.local_rank}") - # self.nccl = NCCLLibrary(library_path) - self.flagcx = FLAGCXLibrary("/mine/ip122/tune_qwen/FlagCX/build/lib/libflagcx.so") # /mine/ip122/tune_qwen/FlagCX /workspace/serving/FlagCX + self.nccl = NCCLLibrary(library_path) if not hostname: hostname = get_ip() @@ -77,7 +69,9 @@ def __init__(self, self.send_store_cv = threading.Condition() self.send_queue_cv = threading.Condition() self.recv_store_cv = threading.Condition() - self.comm_cv = threading.Condition() + + self.send_stream = torch.cuda.Stream() + self.recv_stream = torch.cuda.Stream() # The sending type includes tree mutually exclusive options: # PUT, GET, PUT_ASYNC. 
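        # i.e. three mutually exclusive modes: PUT pushes the tensor to the
        # peer synchronously from the calling thread, PUT_ASYNC queues it for
        # the background sender thread, and GET has the consumer pull the
        # tensor from the producer's send_store on demand.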
@@ -124,18 +118,14 @@ def _create_connect(self, remote_address: typing.Optional[str] = None): remote_address, self.comms) return sock, self.comms[remote_address] - # unique_id = self.nccl.ncclGetUniqueId() - unique_id = self.flagcx.flagcxGetUniqueId().contents - # uid = unique_id.contents + unique_id = self.nccl.ncclGetUniqueId() data = {"cmd": "NEW", "unique_id": bytes(unique_id.internal)} sock.send(msgpack.dumps(data)) with torch.cuda.device(self.device): rank = 0 - # comm: ncclComm_t = self.nccl.ncclCommInitRank( - # 2, unique_id, rank) - comm = self.flagcx.flagcxCommInitRank( - 2, ctypes.byref(unique_id), rank) + comm: ncclComm_t = self.nccl.ncclCommInitRank( + 2, unique_id, rank) self.comms[remote_address] = (comm, rank) logger.info("🤝ncclCommInitRank Success, %s👉%s, MyRank: %s", self.zmq_address, remote_address, rank) @@ -243,7 +233,7 @@ def recv_tensor( device=self.device) start_time = time.time() - self._recv(comm, tensor, rank ^ 1) + self._recv(comm, tensor, rank ^ 1, self.recv_stream) duration = time.time() - start_time logger.info( "🔵[GET]Recv From %s, tensor_id:%s, shape:%s, duration:%.3fms, " @@ -262,14 +252,12 @@ def _listen_for_requests(self): logger.debug("Received message from %s, data:%s", remote_address.decode(), data) if data["cmd"] == "NEW": - unique_id = self.flagcx.unique_id_from_bytes( + unique_id = self.nccl.unique_id_from_bytes( bytes(data["unique_id"])) with torch.cuda.device(self.device): rank = 1 - # comm: ncclComm_t = self.nccl.ncclCommInitRank( - # 2, unique_id, rank) - comm = self.flagcx.flagcxCommInitRank( - 2, ctypes.byref(unique_id), rank) + comm: ncclComm_t = self.nccl.ncclCommInitRank( + 2, unique_id, rank) self.comms[remote_address.decode()] = (comm, rank) logger.info( "🤝ncclCommInitRank Success, %s👈%s, MyRank:%s", @@ -297,7 +285,8 @@ def _listen_for_requests(self): self.router_socket.send_multipart( [remote_address, b"0"]) comm, rank = self.comms[remote_address.decode()] - self._recv(comm, tensor, rank ^ 1) + self._recv(comm, tensor, rank ^ 1, + self.recv_stream) logger.info( "🔵[PUT]Recv Tensor, %s👈%s, MyRank:%s, " "data:%s, shape:%s", self.zmq_address, @@ -337,7 +326,9 @@ def _listen_for_requests(self): [remote_address, msgpack.dumps(data)]) if data["ret"] == 0: - self._send(comm, tensor.to(self.device), rank ^ 1) + comm, rank = self.comms[remote_address.decode()] + self._send(comm, tensor.to(self.device), rank ^ 1, + self.send_stream) logger.info( "🔵[GET]Send Tensor, %s👉%s, " @@ -348,8 +339,6 @@ def _listen_for_requests(self): "🚧Unexpected, Received message from %s, data:%s", remote_address, data) - # Asynchronous sending may cause conflicts between P2P NCCL and - # NCCL used in TP/PP, which can lead to deadlock issues. 
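    # Sends and receives are now issued on dedicated CUDA streams
    # (self.send_stream and self.recv_stream created in __init__) rather than
    # being serialized under a shared communicator lock, which removes the
    # P2P-vs-TP/PP NCCL contention that the deleted comment above warned
    # could deadlock.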
def _send_async(self): while True: with self.send_queue_cv: @@ -405,7 +394,7 @@ def _send_sync( response.decode()) return False - self._send(comm, tensor.to(self.device), rank ^ 1) + self._send(comm, tensor.to(self.device), rank ^ 1, self.send_stream) logger.info("🔵Send Tensor, %s👉%s, MyRank:%s, data:%s, tensor:%s", self.zmq_address, remote_address, rank, data, tensor.shape) return True @@ -431,18 +420,10 @@ def _send(self, comm, tensor: torch.Tensor, dst: int, stream=None): if stream is None: stream = current_stream() - with self.comm_cv: - # self.nccl.ncclSend(buffer_type(tensor.data_ptr()), tensor.numel(), - # ncclDataTypeEnum.from_torch(tensor.dtype), dst, - # comm, cudaStream_t(stream.cuda_stream)) - # self.nccl.flagcxSend(buffer_type(tensor.data_ptr()), tensor.numel(), - # flagcxDataTypeEnum.from_torch(tensor.dtype), dst, - # comm, cudaStream_t(stream.cuda_stream)) - flagcx_stream = self.flagcx.adaptor_stream_copy(stream) - self.flagcx.flagcxSend(buffer_type(tensor.data_ptr()), tensor.numel(), - flagcxDataTypeEnum.from_torch(tensor.dtype), dst, - comm, flagcx_stream) - self.flagcx.adaptor_stream_free(flagcx_stream) + with torch.cuda.stream(stream): + self.nccl.ncclSend(buffer_type(tensor.data_ptr()), tensor.numel(), + ncclDataTypeEnum.from_torch(tensor.dtype), dst, + comm, cudaStream_t(stream.cuda_stream)) def _recv(self, comm, tensor: torch.Tensor, src: int, stream=None): assert tensor.device == self.device, ( @@ -451,18 +432,10 @@ def _recv(self, comm, tensor: torch.Tensor, src: int, stream=None): if stream is None: stream = current_stream() - with self.comm_cv: - # self.nccl.ncclRecv(buffer_type(tensor.data_ptr()), tensor.numel(), - # ncclDataTypeEnum.from_torch(tensor.dtype), src, - # comm, cudaStream_t(stream.cuda_stream)) - # self.flagcx.flagcxRecv(buffer_type(tensor.data_ptr()), tensor.numel(), - # flagcxDataTypeEnum.from_torch(tensor.dtype), src, - # comm, cudaStream_t(stream.cuda_stream)) - flagcx_stream = self.flagcx.adaptor_stream_copy(stream) - self.flagcx.flagcxRecv(buffer_type(tensor.data_ptr()), tensor.numel(), - flagcxDataTypeEnum.from_torch(tensor.dtype), src, - comm, flagcx_stream) - self.flagcx.adaptor_stream_free(flagcx_stream) + with torch.cuda.stream(stream): + self.nccl.ncclRecv(buffer_type(tensor.data_ptr()), tensor.numel(), + ncclDataTypeEnum.from_torch(tensor.dtype), src, + comm, cudaStream_t(stream.cuda_stream)) def close(self) -> None: self._listener_thread.join() From 5c65a78af86ece51c30958a07c9c564f9f423753 Mon Sep 17 00:00:00 2001 From: chenzhuo Date: Tue, 29 Apr 2025 11:25:24 +0800 Subject: [PATCH 51/62] polish code --- .../vllm/distributed/kv_transfer/kv_connector/p2p_connector.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/flagscale/backends/vllm/vllm/distributed/kv_transfer/kv_connector/p2p_connector.py b/flagscale/backends/vllm/vllm/distributed/kv_transfer/kv_connector/p2p_connector.py index 56b2ba9c8..c88444370 100644 --- a/flagscale/backends/vllm/vllm/distributed/kv_transfer/kv_connector/p2p_connector.py +++ b/flagscale/backends/vllm/vllm/distributed/kv_transfer/kv_connector/p2p_connector.py @@ -11,7 +11,7 @@ from vllm import _custom_ops as ops from vllm.config import VllmConfig from vllm.distributed.kv_transfer.kv_connector.base import KVConnectorBase -if os.getenv('USE)FLAGCX', 'False').lower() == 'true': +if os.getenv("USE_FLAGCX", "false").lower() in ("1", "true"): from vllm.distributed.kv_transfer.kv_pipe.flagcx_p2p_nccl_pipe import P2pNcclPipe else: from vllm.distributed.kv_transfer.kv_pipe.p2p_nccl_pipe import 
P2pNcclPipe From adc03c38a380b89f37ea3d90822d8770e4e1b8e2 Mon Sep 17 00:00:00 2001 From: chenzhuo Date: Tue, 29 Apr 2025 11:34:44 +0800 Subject: [PATCH 52/62] fix code --- flagscale/runner/runner_serve.py | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/flagscale/runner/runner_serve.py b/flagscale/runner/runner_serve.py index 6dbd0fde9..1b07262cc 100644 --- a/flagscale/runner/runner_serve.py +++ b/flagscale/runner/runner_serve.py @@ -545,11 +545,10 @@ def _generate_run_script_serve(config, host, node_rank, cmd, background=True, wi ) if before_start_cmd: node_cmd = f"{before_start_cmd} && " + node_cmd - ssh_cmd = f'ssh -n -p {ssh_port} {ip} "{node_cmd}"' - - if docker_name: - ssh_cmd = f"ssh -n -p {ssh_port} {ip} \"docker exec {docker_name} /bin/bash -c '{node_cmd}'\"" - f.write(f"{ssh_cmd}\n") + ssh_cmd = f'ssh -n -p {ssh_port} {ip} "{node_cmd}"' + if docker_name: + ssh_cmd = f"ssh -n -p {ssh_port} {ip} \"docker exec {docker_name} /bin/bash -c '{node_cmd}'\"" + f.write(f"{ssh_cmd}\n") else: # Note: config key device_type is specified for single node serving in neither gpu or cpu. device_type = None From 0df9f7538063044599f1dbfee12aab048bbd8da1 Mon Sep 17 00:00:00 2001 From: chenzhuo Date: Tue, 29 Apr 2025 11:58:25 +0800 Subject: [PATCH 53/62] polish code --- flagscale/runner/runner_serve.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/flagscale/runner/runner_serve.py b/flagscale/runner/runner_serve.py index 1b07262cc..2b29b18aa 100644 --- a/flagscale/runner/runner_serve.py +++ b/flagscale/runner/runner_serve.py @@ -383,8 +383,7 @@ def _generate_run_script_serve(config, host, node_rank, cmd, background=True, wi }, } logger.info( - f"============= prefill instance {i}, p_kv_config: {p_kv_config} =============", - flush=True, + f"============= prefill instance {i}, p_kv_config: {p_kv_config} =============" ) card_ids = resource_manager.get_available_card_ids( address=p_address, num=each_instance_card_num @@ -427,8 +426,7 @@ def _generate_run_script_serve(config, host, node_rank, cmd, background=True, wi }, } logger.info( - f"============= decode instance {i}, d_kv_config: {d_kv_config} =============", - flush=True, + f"============= decode instance {i}, d_kv_config: {d_kv_config} =============" ) card_ids = resource_manager.get_available_card_ids( address=d_address, num=each_instance_card_num From 47ac748bb5cfb426bc6ee8ba020bdfb2881e44d6 Mon Sep 17 00:00:00 2001 From: chenzhuo Date: Tue, 29 Apr 2025 12:02:37 +0800 Subject: [PATCH 54/62] polish code --- .../qwen/conf/config_qwen2.5_7b_pd_disaggregation.yaml | 8 +++++++- examples/qwen/conf/serve/serve_qwen2.5_7b.yaml | 2 +- 2 files changed, 8 insertions(+), 2 deletions(-) diff --git a/examples/qwen/conf/config_qwen2.5_7b_pd_disaggregation.yaml b/examples/qwen/conf/config_qwen2.5_7b_pd_disaggregation.yaml index aefe49c83..e48008890 100644 --- a/examples/qwen/conf/config_qwen2.5_7b_pd_disaggregation.yaml +++ b/examples/qwen/conf/config_qwen2.5_7b_pd_disaggregation.yaml @@ -20,8 +20,14 @@ experiment: docker: fr-v2 envs: CUDA_DEVICE_MAX_CONNECTIONS: 1 + VLLM_USE_V1: 0 + FLAGCX_SOCKET_IFNAME: bond0 + FLAGCX_PATH: /mine/ip122/tune_qwen/FlagCX/ + FLAGCX_DEBUG: TRACE + FLAGCX_DEBUG_SUBSYS: ALL + USE_FLAGCX: true cmds: - before_start: export VLLM_USE_V1=0 && export FLAGCX_SOCKET_IFNAME=bond0 && export VLLM_USE_V1=0 && source /root/miniconda3/bin/activate flagscale-inference && export FLAGCX_PATH=/mine/ip122/tune_qwen/FlagCX/ && export FLAGCX_DEBUG=TRACE && export FLAGCX_DEBUG_SUBSYS=ALL && export 
NCCL_DEBUG=TRACE && export NCCL_DEBUG_SUBSYS=ALL + before_start: source /root/miniconda3/bin/activate flagscale-inference action: run diff --git a/examples/qwen/conf/serve/serve_qwen2.5_7b.yaml b/examples/qwen/conf/serve/serve_qwen2.5_7b.yaml index 3a7a20130..49e39288d 100644 --- a/examples/qwen/conf/serve/serve_qwen2.5_7b.yaml +++ b/examples/qwen/conf/serve/serve_qwen2.5_7b.yaml @@ -5,7 +5,7 @@ host: 0.0.0.0 tensor_parallel_size: 1 pipeline_parallel_size: 1 - gpu_memory_utilization: 0.1 + gpu_memory_utilization: 0.9 max_model_len: 32768 max_num_seqs: 256 enforce_eager: true From 01680c536d13136b5f13d378d0a60ece957d7e0f Mon Sep 17 00:00:00 2001 From: chenzhuo Date: Tue, 29 Apr 2025 12:25:21 +0800 Subject: [PATCH 55/62] polish code --- examples/qwen/conf/config_qwen2.5_7b_pd_disaggregation.yaml | 2 +- examples/qwen/conf/serve/serve_qwen2.5_7b.yaml | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/examples/qwen/conf/config_qwen2.5_7b_pd_disaggregation.yaml b/examples/qwen/conf/config_qwen2.5_7b_pd_disaggregation.yaml index e48008890..2e4c723dd 100644 --- a/examples/qwen/conf/config_qwen2.5_7b_pd_disaggregation.yaml +++ b/examples/qwen/conf/config_qwen2.5_7b_pd_disaggregation.yaml @@ -22,7 +22,7 @@ experiment: CUDA_DEVICE_MAX_CONNECTIONS: 1 VLLM_USE_V1: 0 FLAGCX_SOCKET_IFNAME: bond0 - FLAGCX_PATH: /mine/ip122/tune_qwen/FlagCX/ + FLAGCX_PATH: /path/to/FlagCX/ FLAGCX_DEBUG: TRACE FLAGCX_DEBUG_SUBSYS: ALL USE_FLAGCX: true diff --git a/examples/qwen/conf/serve/serve_qwen2.5_7b.yaml b/examples/qwen/conf/serve/serve_qwen2.5_7b.yaml index 49e39288d..b1bd95bfb 100644 --- a/examples/qwen/conf/serve/serve_qwen2.5_7b.yaml +++ b/examples/qwen/conf/serve/serve_qwen2.5_7b.yaml @@ -1,7 +1,7 @@ - serve_id: vllm_model engine: vllm engine_args: - model: /models/Qwen2.5-0.5B-Instruct + model: /models/Qwen2.5-7B-Instruct host: 0.0.0.0 tensor_parallel_size: 1 pipeline_parallel_size: 1 From 058cf508f6d937dd1ae713e6d725bd1cb00c0856 Mon Sep 17 00:00:00 2001 From: chenzhuo Date: Tue, 29 Apr 2025 16:36:44 +0800 Subject: [PATCH 56/62] fix code --- .../device_communicators/flagcx_wrapper.py | 405 ------------------ .../kv_pipe/flagcx_p2p_nccl_pipe.py | 7 +- 2 files changed, 4 insertions(+), 408 deletions(-) delete mode 100644 flagscale/backends/vllm/vllm/distributed/device_communicators/flagcx_wrapper.py diff --git a/flagscale/backends/vllm/vllm/distributed/device_communicators/flagcx_wrapper.py b/flagscale/backends/vllm/vllm/distributed/device_communicators/flagcx_wrapper.py deleted file mode 100644 index 634db5852..000000000 --- a/flagscale/backends/vllm/vllm/distributed/device_communicators/flagcx_wrapper.py +++ /dev/null @@ -1,405 +0,0 @@ -# Copyright (c) 2025, BAAI. All rights reserved. -# Adopted from https://github.com/vllm-project/vllm/blob/main/vllm/distributed/device_communicators/pynccl_wrapper.py. 
Below is the original copyright: -# SPDX-License-Identifier: Apache-2.0 -import ctypes -import platform -from dataclasses import dataclass -from typing import Any, Dict, List, Optional - -import torch -from torch.distributed import ReduceOp - -# === export types and functions from flagcx to Python === -# for the original flagcx definition, please check -# https://github.com/FlagOpen/FlagCX/blob/main/flagcx/include/flagcx.h - -flagcxResult_t = ctypes.c_int -flagcxDataType_t = ctypes.c_int -flagcxRedOp_t = ctypes.c_int -flagcxMemcpyType_t = ctypes.c_int -flagcxMemType_t = ctypes.c_int - -flagcxHandlerGroup_t = ctypes.c_void_p -flagcxComm_t = ctypes.c_void_p -flagcxEvent_t = ctypes.c_void_p -cudaStream_t = ctypes.c_void_p -buffer_type = ctypes.c_void_p - - -class flagcxStream(ctypes.Structure): - _fields_ = [("base", cudaStream_t)] -flagcxStream_t = ctypes.POINTER(flagcxStream) - - -class flagcxUniqueId(ctypes.Structure): - _fields_ = [("internal", ctypes.c_byte * 256)] -flagcxUniqueId_t = ctypes.POINTER(flagcxUniqueId) - - -DEVICE_SYNCHRONIZE_FUNCTYPE = ctypes.CFUNCTYPE(flagcxResult_t) -DEVICE_MEMCPY_FUNCTYPE = ctypes.CFUNCTYPE( - flagcxResult_t, ctypes.c_void_p, ctypes.c_void_p, ctypes.c_size_t, - flagcxMemcpyType_t, flagcxStream_t -) -DEVICE_MEMSET_FUNCTYPE = ctypes.CFUNCTYPE( - flagcxResult_t, ctypes.c_void_p, ctypes.c_int, ctypes.c_size_t, - flagcxMemType_t, flagcxStream_t -) -DEVICE_MALLOC_FUNCTYPE = ctypes.CFUNCTYPE( - flagcxResult_t, ctypes.POINTER(ctypes.c_void_p), ctypes.c_size_t, - flagcxMemType_t, flagcxStream_t -) -DEVICE_FREE_FUNCTYPE = ctypes.CFUNCTYPE( - flagcxResult_t, ctypes.c_void_p, flagcxMemType_t, flagcxStream_t -) -SET_DEVICE_FUNCTYPE = ctypes.CFUNCTYPE(flagcxResult_t, ctypes.c_int) -GET_DEVICE_FUNCTYPE = ctypes.CFUNCTYPE(flagcxResult_t, ctypes.POINTER(ctypes.c_int)) -GET_DEVICE_COUNT_FUNCTYPE = ctypes.CFUNCTYPE(flagcxResult_t, ctypes.POINTER(ctypes.c_int)) -GET_VENDOR_FUNCTYPE = ctypes.CFUNCTYPE(flagcxResult_t, ctypes.c_char_p) - -STREAM_CREATE_FUNCTYPE = ctypes.CFUNCTYPE(flagcxResult_t, ctypes.POINTER(flagcxStream_t)) -STREAM_DESTROY_FUNCTYPE = ctypes.CFUNCTYPE(flagcxResult_t, flagcxStream_t) -STREAM_COPY_FUNCTYPE = ctypes.CFUNCTYPE(flagcxResult_t, ctypes.POINTER(flagcxStream_t), ctypes.c_void_p) -STREAM_FREE_FUNCTYPE = ctypes.CFUNCTYPE(flagcxResult_t, flagcxStream_t) -STREAM_SYNCHRONIZE_FUNCTYPE = ctypes.CFUNCTYPE(flagcxResult_t, flagcxStream_t) -STREAM_QUERY_FUNCTYPE = ctypes.CFUNCTYPE(flagcxResult_t, flagcxStream_t) -STREAM_WAIT_EVENT_FUNCTYPE = ctypes.CFUNCTYPE(flagcxResult_t, flagcxStream_t, flagcxEvent_t) - -EVENT_CREATE_FUNCTYPE = ctypes.CFUNCTYPE(flagcxResult_t, ctypes.POINTER(flagcxEvent_t)) -EVENT_DESTROY_FUNCTYPE = ctypes.CFUNCTYPE(flagcxResult_t, flagcxEvent_t) -EVENT_RECORD_FUNCTYPE = ctypes.CFUNCTYPE(flagcxResult_t, flagcxEvent_t, flagcxStream_t) -EVENT_SYNCHRONIZE_FUNCTYPE = ctypes.CFUNCTYPE(flagcxResult_t, flagcxEvent_t) -EVENT_QUERY_FUNCTYPE = ctypes.CFUNCTYPE(flagcxResult_t, flagcxEvent_t) -class flagcxDeviceHandle(ctypes.Structure): - _fields_ = [ - # Basic functions - ("deviceSynchronize", DEVICE_SYNCHRONIZE_FUNCTYPE), - ("deviceMemcpy", DEVICE_MEMCPY_FUNCTYPE), - ("deviceMemset", DEVICE_MEMSET_FUNCTYPE), - ("deviceMalloc", DEVICE_MALLOC_FUNCTYPE), - ("deviceFree", DEVICE_FREE_FUNCTYPE), - ("setDevice", SET_DEVICE_FUNCTYPE), - ("getDevice", GET_DEVICE_FUNCTYPE), - ("getDeviceCount", GET_DEVICE_COUNT_FUNCTYPE), - ("getVendor", GET_VENDOR_FUNCTYPE), - # Stream functions - ("streamCreate", STREAM_CREATE_FUNCTYPE), - ("streamDestroy", 
STREAM_DESTROY_FUNCTYPE), - ("streamCopy", STREAM_COPY_FUNCTYPE), - ("streamFree", STREAM_FREE_FUNCTYPE), - ("streamSynchronize", STREAM_SYNCHRONIZE_FUNCTYPE), - ("streamQuery", STREAM_QUERY_FUNCTYPE), - ("streamWaitEvent", STREAM_WAIT_EVENT_FUNCTYPE), - # Event functions - ("eventCreate", EVENT_CREATE_FUNCTYPE), - ("eventDestroy", EVENT_DESTROY_FUNCTYPE), - ("eventRecord", EVENT_RECORD_FUNCTYPE), - ("eventSynchronize", EVENT_SYNCHRONIZE_FUNCTYPE), - ("eventQuery", EVENT_QUERY_FUNCTYPE), - ] -flagcxDeviceHandle_t = ctypes.POINTER(flagcxDeviceHandle) - -class flagcxHandlerGroup(ctypes.Structure): - _fields_ = [ - ("uniqueId", flagcxUniqueId_t), - ("comm", flagcxComm_t), - ("devHandle", flagcxDeviceHandle_t), - ] -flagcxHandlerGroup_t = ctypes.POINTER(flagcxHandlerGroup) - - -class flagcxDataTypeEnum: - flagcxInt8 = 0 - flagcxChar = 0 - flagcxUint8 = 1 - flagcxInt32 = 2 - flagcxInt = 2 - flagcxUint32 = 3 - flagcxInt64 = 4 - flagcxUint64 = 5 - flagcxFloat16 = 6 - flagcxHalf = 6 - flagcxFloat32 = 7 - flagcxFloat = 7 - flagcxFloat64 = 8 - flagcxDouble = 8 - flagcxBfloat16 = 9 - flagcxNumTypes = 10 - - @classmethod - def from_torch(cls, dtype: torch.dtype) -> int: - if dtype == torch.int8: - return cls.flagcxInt8 - if dtype == torch.uint8: - return cls.flagcxUint8 - if dtype == torch.int32: - return cls.flagcxInt32 - if dtype == torch.int64: - return cls.flagcxInt64 - if dtype == torch.float16: - return cls.flagcxFloat16 - if dtype == torch.float32: - return cls.flagcxFloat32 - if dtype == torch.float64: - return cls.flagcxFloat64 - if dtype == torch.bfloat16: - return cls.flagcxBfloat16 - raise ValueError(f"Unsupported dtype: {dtype}") - - -class flagcxRedOpTypeEnum: - flagcxSum = 0 - flagcxProd = 1 - flagcxMax = 2 - flagcxMin = 3 - flagcxAvg = 4 - flagcxNumOps = 5 - - @classmethod - def from_torch(cls, op: ReduceOp) -> int: - if op == ReduceOp.SUM: - return cls.flagcxSum - if op == ReduceOp.PRODUCT: - return cls.flagcxProd - if op == ReduceOp.MAX: - return cls.flagcxMax - if op == ReduceOp.MIN: - return cls.flagcxMin - if op == ReduceOp.AVG: - return cls.flagcxAvg - raise ValueError(f"Unsupported op: {op}") - - -@dataclass -class Function: - name: str - restype: Any - argtypes: List[Any] - - -class FLAGCXLibrary: - exported_functions = [ - Function("flagcxHandleInit", flagcxResult_t, - [ctypes.POINTER(flagcxHandlerGroup_t)]), - Function("flagcxHandleFree", flagcxResult_t, - [flagcxHandlerGroup_t]), - Function("flagcxGetErrorString", ctypes.c_char_p, [flagcxResult_t]), - Function("flagcxGetVersion", flagcxResult_t, - [ctypes.POINTER(ctypes.c_int)]), - Function("flagcxGetUniqueId", flagcxResult_t, - [ctypes.POINTER(ctypes.POINTER(flagcxUniqueId))]), - # Note that flagcxComm_t is a pointer type, so the first argument - # is a pointer to a pointer - Function("flagcxCommInitRank", flagcxResult_t, [ - ctypes.POINTER(flagcxComm_t), ctypes.c_int, ctypes.POINTER(flagcxUniqueId), - ctypes.c_int - ]), - # Note that flagcxStream_t is a pointer type, so the last argument - # is a pointer - Function("flagcxAllReduce", flagcxResult_t, [ - buffer_type, buffer_type, ctypes.c_size_t, flagcxDataType_t, - flagcxRedOp_t, flagcxComm_t, flagcxStream_t - ]), - - # Note that flagcxStream_t is a pointer type, so the last argument - # is a pointer - Function("flagcxAllGather", flagcxResult_t, [ - buffer_type, buffer_type, ctypes.c_size_t, flagcxDataType_t, - flagcxComm_t, flagcxStream_t - ]), - - # Note that flagcxStream_t is a pointer type, so the last argument - # is a pointer - Function("flagcxReduceScatter", 
flagcxResult_t, [ - buffer_type, buffer_type, ctypes.c_size_t, flagcxDataType_t, - flagcxRedOp_t, flagcxComm_t, flagcxStream_t - ]), - - Function("flagcxSend", flagcxResult_t, [ - buffer_type, ctypes.c_size_t, flagcxDataType_t, ctypes.c_int, - flagcxComm_t, flagcxStream_t - ]), - - Function("flagcxRecv", flagcxResult_t, [ - buffer_type, ctypes.c_size_t, flagcxDataType_t, ctypes.c_int, - flagcxComm_t, flagcxStream_t - ]), - - Function("flagcxBroadcast", flagcxResult_t, [ - buffer_type, buffer_type, ctypes.c_size_t, flagcxDataType_t, - ctypes.c_int, flagcxComm_t, flagcxStream_t - ]), - - # be cautious! this is a collective call, it will block until all - # processes in the communicator have called this function. - # because Python object destruction can happen in random order, - # it is better not to call it at all. - # flagcxResult_t flagcxCommDestroy(flagcxComm_t comm); - Function("flagcxCommDestroy", flagcxResult_t, [flagcxComm_t]), - ] - - # class attribute to store the mapping from the path to the library - # to avoid loading the same library multiple times - path_to_library_cache: Dict[str, Any] = {} - - # class attribute to store the mapping from library path - # to the corresponding dictionary - path_to_dict_mapping: Dict[str, Dict[str, Any]] = {} - - def __init__(self, so_file: Optional[str] = None): - - - try: - if so_file not in FLAGCXLibrary.path_to_dict_mapping: - lib = ctypes.CDLL(so_file) - FLAGCXLibrary.path_to_library_cache[so_file] = lib - self.lib = FLAGCXLibrary.path_to_library_cache[so_file] - except Exception as e: - raise e - - if so_file not in FLAGCXLibrary.path_to_dict_mapping: - _funcs: Dict[str, Any] = {} - for func in FLAGCXLibrary.exported_functions: - f = getattr(self.lib, func.name) - f.restype = func.restype - f.argtypes = func.argtypes - _funcs[func.name] = f - FLAGCXLibrary.path_to_dict_mapping[so_file] = _funcs - self._funcs = FLAGCXLibrary.path_to_dict_mapping[so_file] - - # init flagcx handler to call device-related apis - self.handler = flagcxHandlerGroup_t() - self.FLAGCX_CHECK(self._funcs["flagcxHandleInit"](ctypes.byref(self.handler))) - - def __del__(self): - # free flagcx handler - self.FLAGCX_CHECK(self._funcs["flagcxHandleFree"](self.handler)) - - def flagcxGetErrorString(self, result: flagcxResult_t) -> str: - return self._funcs["flagcxGetErrorString"](result).decode("utf-8") - - def FLAGCX_CHECK(self, result: flagcxResult_t) -> None: - if result != 0: - error_str = self.flagcxGetErrorString(result) - raise RuntimeError(f"FLAGCX error: {error_str}") - - def flagcxGetVersion(self) -> str: - version = ctypes.c_int() - self.FLAGCX_CHECK(self._funcs["flagcxGetVersion"](ctypes.byref(version))) - version_str = str(version.value) - # something like 21903 --> "2.19.3" - major = version_str[0].lstrip("0") - minor = version_str[1:3].lstrip("0") - patch = version_str[3:].lstrip("0") - return f"{major}.{minor}.{patch}" - - def flagcxGetUniqueId(self) -> flagcxUniqueId: - unique_id = ctypes.POINTER(flagcxUniqueId)() - self.FLAGCX_CHECK(self._funcs["flagcxGetUniqueId"]( - ctypes.byref(unique_id))) - return unique_id - - def unique_id_from_bytes(self, data: bytes) -> flagcxUniqueId: - """ - Reconstructs an `ncclUniqueId` object from bytes data. - Args: - data: Must be a 128-byte data block (matching NCCL's unique_id). - Returns: - ncclUniqueId: The reconstructed NCCL Unique ID object. - Raises: - ValueError: If the input data length is not 128 bytes. 
- """ - if len(data) != 256: - raise ValueError( - f"Expected 256 bytes for ncclUniqueId, got {len(data)} bytes") - - unique_id = flagcxUniqueId() - ctypes.memmove(ctypes.addressof(unique_id.internal), data, 256) - return unique_id - - def flagcxCommInitRank(self, world_size: int, unique_id: flagcxUniqueId, - rank: int) -> flagcxComm_t: - comm = flagcxComm_t() - self.FLAGCX_CHECK(self._funcs["flagcxCommInitRank"](ctypes.byref(comm), - world_size, unique_id, - rank)) - return comm - - def flagcxAllReduce(self, sendbuff: buffer_type, recvbuff: buffer_type, - count: int, datatype: int, op: int, comm: flagcxComm_t, - stream: flagcxStream_t) -> None: - # `datatype` actually should be `flagcxDataType_t` - # and `op` should be `flagcxRedOp_t` - # both are aliases of `ctypes.c_int` - # when we pass int to a function, it will be converted to `ctypes.c_int` - # by ctypes automatically - self.FLAGCX_CHECK(self._funcs["flagcxAllReduce"](sendbuff, recvbuff, count, - datatype, op, comm, - stream)) - - def flagcxReduceScatter(self, sendbuff: buffer_type, recvbuff: buffer_type, - count: int, datatype: int, op: int, comm: flagcxComm_t, - stream: flagcxStream_t) -> None: - # `datatype` actually should be `flagcxDataType_t` - # and `op` should be `flagcxRedOp_t` - # both are aliases of `ctypes.c_int` - # when we pass int to a function, it will be converted to `ctypes.c_int` - # by ctypes automatically - self.FLAGCX_CHECK(self._funcs["flagcxReduceScatter"](sendbuff, recvbuff, - count, datatype, op, - comm, stream)) - - def flagcxAllGather(self, sendbuff: buffer_type, recvbuff: buffer_type, - count: int, datatype: int, comm: flagcxComm_t, - stream: flagcxStream_t) -> None: - # `datatype` actually should be `flagcxDataType_t` - # which is an aliases of `ctypes.c_int` - # when we pass int to a function, it will be converted to `ctypes.c_int` - # by ctypes automatically - self.FLAGCX_CHECK(self._funcs["flagcxAllGather"](sendbuff, recvbuff, count, - datatype, comm, stream)) - - def flagcxSend(self, sendbuff: buffer_type, count: int, datatype: int, - dest: int, comm: flagcxComm_t, stream: flagcxStream_t) -> None: - self.FLAGCX_CHECK(self._funcs["flagcxSend"](sendbuff, count, datatype, - dest, comm, stream)) - - def flagcxRecv(self, recvbuff: buffer_type, count: int, datatype: int, - src: int, comm: flagcxComm_t, stream: flagcxStream_t) -> None: - self.FLAGCX_CHECK(self._funcs["flagcxRecv"](recvbuff, count, datatype, src, - comm, stream)) - - def flagcxBroadcast(self, sendbuff: buffer_type, recvbuff: buffer_type, - count: int, datatype: int, root: int, comm: flagcxComm_t, - stream: flagcxStream_t) -> None: - self.FLAGCX_CHECK(self._funcs["flagcxBroadcast"](sendbuff, recvbuff, count, - datatype, root, comm, - stream)) - - def flagcxCommDestroy(self, comm: flagcxComm_t) -> None: - self.FLAGCX_CHECK(self._funcs["flagcxCommDestroy"](comm)) - - def adaptor_stream_create(self): - new_stream = flagcxStream_t() - self.FLAGCX_CHECK(self.handler.contents.devHandle.contents.streamCreate(ctypes.byref(new_stream))) - return new_stream - - def adaptor_stream_copy(self, old_stream): - new_stream = flagcxStream_t() - self.FLAGCX_CHECK(self.handler.contents.devHandle.contents.streamCopy(ctypes.byref(new_stream), ctypes.byref(cudaStream_t(old_stream.cuda_stream)))) - return new_stream - - def adaptor_stream_free(self, stream): - self.FLAGCX_CHECK(self.handler.contents.devHandle.contents.streamFree(stream)) - - def adaptor_stream_destroy(self, stream): - self.FLAGCX_CHECK(self.handler.contents.devHandle.contents.streamDestroy(stream)) 
- - def sync_stream(self, stream): - self.FLAGCX_CHECK(self.handler.contents.devHandle.contents.streamSynchronize(stream)) - - -__all__ = [ - "FLAGCXLibrary", "flagcxDataTypeEnum", "flagcxRedOpTypeEnum", "flagcxUniqueId", - "flagcxHandlerGroup_t", "flagcxComm_t", "flagcxStream_t", "flagcxEvent_t", "buffer_type", "cudaStream_t" -] \ No newline at end of file diff --git a/flagscale/backends/vllm/vllm/distributed/kv_transfer/kv_pipe/flagcx_p2p_nccl_pipe.py b/flagscale/backends/vllm/vllm/distributed/kv_transfer/kv_pipe/flagcx_p2p_nccl_pipe.py index b407b54df..48b9fa2df 100644 --- a/flagscale/backends/vllm/vllm/distributed/kv_transfer/kv_pipe/flagcx_p2p_nccl_pipe.py +++ b/flagscale/backends/vllm/vllm/distributed/kv_transfer/kv_pipe/flagcx_p2p_nccl_pipe.py @@ -13,15 +13,16 @@ import torch import zmq import ctypes - -from vllm.config import KVTransferConfig -from vllm.distributed.device_communicators.flagcx_wrapper import ( +import sys +sys.path.append(os.getenv('FLAGCX_PATH')) +from plugin.inter_service.flagcx_wrapper import ( FLAGCXLibrary, buffer_type, cudaStream_t, flagcxComm_t, flagcxDataTypeEnum, ) +from vllm.config import KVTransferConfig from vllm.utils import current_stream, get_ip logger = logging.getLogger(__name__) From 23a4318432b7efe5f450454a90ef693284318626 Mon Sep 17 00:00:00 2001 From: chenzhuo Date: Tue, 29 Apr 2025 16:52:09 +0800 Subject: [PATCH 57/62] fix code --- flagscale/runner/runner_serve.py | 6 +++--- ..._pd_disagg_router.py => run_pd_disaggregation_router.py} | 0 2 files changed, 3 insertions(+), 3 deletions(-) rename flagscale/serve/{run_pd_disagg_router.py => run_pd_disaggregation_router.py} (100%) diff --git a/flagscale/runner/runner_serve.py b/flagscale/runner/runner_serve.py index 2b29b18aa..1a06b193f 100644 --- a/flagscale/runner/runner_serve.py +++ b/flagscale/runner/runner_serve.py @@ -363,7 +363,7 @@ def _generate_run_script_serve(config, host, node_rank, cmd, background=True, wi f.write("pkill -f 'run_inference_engine'\n") f.write("pkill -f 'run_fs_serve_vllm'\n") f.write("pkill -f 'vllm serve'\n") - f.write("pkill -f 'run_pd_disagg_router'\n") + f.write("pkill -f 'run_pd_disaggregation_router'\n") f.write(f"mkdir -p {default_log_dir}\n") f.write(f"\n") @@ -669,9 +669,9 @@ def _prepare(self): self.user_args = _get_args_vllm(self.config) self.user_envs = self.config.experiment.get("envs", {}) entrypoint = self.config.experiment.task.get("entrypoint", None) - if self.inference_engine: # pd_disagg_router + if self.inference_engine: if self.config.experiment.get("deploy", {}).get("prefill_decode_disaggregation", False): - self.user_script = "flagscale/serve/run_pd_disagg_router.py" + self.user_script = "flagscale/serve/run_pd_disaggregation_router.py" elif not self.use_fs_serve: self.user_script = "flagscale/serve/run_inference_engine.py" else: diff --git a/flagscale/serve/run_pd_disagg_router.py b/flagscale/serve/run_pd_disaggregation_router.py similarity index 100% rename from flagscale/serve/run_pd_disagg_router.py rename to flagscale/serve/run_pd_disaggregation_router.py From 38d2b73544aed3c0733a12522dc660583b2b7a34 Mon Sep 17 00:00:00 2001 From: chenzhuo Date: Tue, 29 Apr 2025 16:55:46 +0800 Subject: [PATCH 58/62] polish code --- .../distributed/kv_transfer/kv_pipe/flagcx_p2p_nccl_pipe.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/flagscale/backends/vllm/vllm/distributed/kv_transfer/kv_pipe/flagcx_p2p_nccl_pipe.py b/flagscale/backends/vllm/vllm/distributed/kv_transfer/kv_pipe/flagcx_p2p_nccl_pipe.py index 48b9fa2df..3ceec7936 100644 --- 
a/flagscale/backends/vllm/vllm/distributed/kv_transfer/kv_pipe/flagcx_p2p_nccl_pipe.py +++ b/flagscale/backends/vllm/vllm/distributed/kv_transfer/kv_pipe/flagcx_p2p_nccl_pipe.py @@ -15,7 +15,7 @@ import ctypes import sys sys.path.append(os.getenv('FLAGCX_PATH')) -from plugin.inter_service.flagcx_wrapper import ( +from plugin.interservice.flagcx_wrapper import ( FLAGCXLibrary, buffer_type, cudaStream_t, From 12a8154de856c221cebdab3e1a8833f9fe536277 Mon Sep 17 00:00:00 2001 From: chenzhuo Date: Tue, 29 Apr 2025 17:03:39 +0800 Subject: [PATCH 59/62] fix code --- flagscale/runner/runner_serve.py | 126 +------------------------------ flagscale/runner/utils.py | 125 ++++++++++++++++++++++++++++++ 2 files changed, 126 insertions(+), 125 deletions(-) diff --git a/flagscale/runner/runner_serve.py b/flagscale/runner/runner_serve.py index 1a06b193f..64a7998c9 100644 --- a/flagscale/runner/runner_serve.py +++ b/flagscale/runner/runner_serve.py @@ -13,6 +13,7 @@ from flagscale.runner.runner_base import JobStatus, RunnerBase from flagscale.runner.utils import ( + ResourceManager, benchmark, dummy_random_input, flatten_dict_to_args, @@ -34,131 +35,6 @@ def _get_multiple_free_ports(num=1, exclude_ports=[]): return allocated_ports -class ResourceManager: - def __init__(self, nodes): - """ - Initialize the ResourceManager with a list of nodes. - Each element in the list should be a two-item list: - - The first item is the node address (a string). - - The second item is a dictionary containing at least the key "slots". - If "type" is not provided, it defaults to "gpu" with a warning. - The first node is treated as the master node, and the rest are worker nodes. - """ - self.nodes = self._initialize_nodes(nodes) - - def _initialize_nodes(self, nodes): - """ - Convert the input nodes list into the internal nodes representation. - Each node is converted into a dictionary with keys: - "address", "slots", "type", and "used" (initialized to 0). - If the "type" is not provided in a node, default it to "gpu" and issue a warning. - """ - initialized_nodes = [] - for node in nodes: - if len(node) != 2: - raise ValueError("Each node must include an address and node data") - address, info = node - if "slots" not in info: - raise ValueError("Node data must contain 'slots'") - if "type" not in info: - logger.warning( - f"Node {address} does not provide a resource type. Defaulting to 'gpu'." - ) - resource_type = info.get("type", "gpu") - initialized_nodes.append( - { - "address": address, - "slots": info["slots"], - "type": resource_type, - "used": 0, # Initialize used slot count to 0 - } - ) - return initialized_nodes - - def get_whole_card_num(self, resource_type="gpu"): - """ - Return the total number of slots across all nodes with the specified resource type. - The return type is int. - """ - total = 0 - for node in self.nodes: - if node["type"] == resource_type: - total += node["slots"] - return total - - def get_available_card_num(self, resource_type="gpu"): - """ - Return the total number of available slots (slots minus used) across all nodes with the specified resource type. - The return type is int. - """ - total = 0 - for node in self.nodes: - if node["type"] == resource_type: - total += node["slots"] - node["used"] - return total - - def get_available_card_ids(self, resource_type="gpu", address="auto", num=1): - """ - Allocate 'num' resource cards from a node and return a list of card indices. - - For the default case (address="auto"), traverse nodes in order: master node first, then worker nodes. 
-          - If a node's available slots (slots - used) are >= num, allocate num consecutive indices (based on the current used value)
-            and update the node's used count, returning the allocated indices (0-indexed) as a list.
-          - If the available slots are insufficient at a particular node and address is "auto", continue searching through other nodes.
-          - If an explicit address is provided, check only that node; if it doesn't exist or lacks sufficient available slots, raise an error.
-          - If none of the nodes can satisfy the request, raise an error indicating insufficient resources.
-        """
-        # Check the specified node if address is not "auto"
-        if address != "auto":
-            node_found = None
-            for node in self.nodes:
-                if node["address"] == address and node["type"] == resource_type:
-                    node_found = node
-                    break
-            if node_found is None:
-                raise ValueError(f"Node {address} does not exist or resource type mismatch")
-            free = node_found["slots"] - node_found["used"]
-            if free < num:
-                raise ValueError("Insufficient resources")
-            allocated_ids = list(range(node_found["used"], node_found["used"] + num))
-            node_found["used"] += num
-            return allocated_ids
-
-        # For address == "auto", traverse all nodes (master node first, then worker nodes)
-        for node in self.nodes:
-            if node["type"] == resource_type:
-                free = node["slots"] - node["used"]
-                if free >= num:
-                    allocated_ids = list(range(node["used"], node["used"] + num))
-                    node["used"] += num
-                    return allocated_ids
-
-        # If no node satisfies the allocation request, raise an error.
-        resource_status = self.get_status()
-        raise ValueError(
-            f"Require number {num} of resource_type {resource_type} But there is insufficient resources: \n{resource_status}"
-        )
-
-    def get_status(self):
-        """
-        Return the status of all nodes as a dictionary.
-        Each key in the returned dictionary is the node's address, and its value is a dictionary with:
-          - type: the resource type.
-          - slots: the total number of slots.
-          - used: the number of allocated slots.
-          - available: the number of available slots (slots - used).
-        """
-        status = {}
-        for node in self.nodes:
-            status[node["address"]] = {
-                "type": node["type"],
-                "slots": node["slots"],
-                "used": node["used"],
-                "available": node["slots"] - node["used"],
-            }
-        return status
-
-
 def _get_args_vllm(config: DictConfig):
     # see the following link for more details
     # https://github.com/facebookresearch/hydra/discussions/2750
diff --git a/flagscale/runner/utils.py b/flagscale/runner/utils.py
index b03dc8ffa..ae249e3e9 100644
--- a/flagscale/runner/utils.py
+++ b/flagscale/runner/utils.py
@@ -568,3 +568,128 @@ def process_one_metric(
     print("=" * 50)
 
     return result
+
+
+class ResourceManager:
+    def __init__(self, nodes):
+        """
+        Initialize the ResourceManager with a list of nodes.
+        Each element in the list should be a two-item list:
+          - The first item is the node address (a string).
+          - The second item is a dictionary containing at least the key "slots".
+            If "type" is not provided, it defaults to "gpu" with a warning.
+        The first node is treated as the master node, and the rest are worker nodes.
+        """
+        self.nodes = self._initialize_nodes(nodes)
+
+    def _initialize_nodes(self, nodes):
+        """
+        Convert the input nodes list into the internal nodes representation.
+        Each node is converted into a dictionary with keys:
+        "address", "slots", "type", and "used" (initialized to 0).
+        If the "type" is not provided in a node, default it to "gpu" and issue a warning.
+ """ + initialized_nodes = [] + for node in nodes: + if len(node) != 2: + raise ValueError("Each node must include an address and node data") + address, info = node + if "slots" not in info: + raise ValueError("Node data must contain 'slots'") + if "type" not in info: + logger.warning( + f"Node {address} does not provide a resource type. Defaulting to 'gpu'." + ) + resource_type = info.get("type", "gpu") + initialized_nodes.append( + { + "address": address, + "slots": info["slots"], + "type": resource_type, + "used": 0, # Initialize used slot count to 0 + } + ) + return initialized_nodes + + def get_whole_card_num(self, resource_type="gpu"): + """ + Return the total number of slots across all nodes with the specified resource type. + The return type is int. + """ + total = 0 + for node in self.nodes: + if node["type"] == resource_type: + total += node["slots"] + return total + + def get_available_card_num(self, resource_type="gpu"): + """ + Return the total number of available slots (slots minus used) across all nodes with the specified resource type. + The return type is int. + """ + total = 0 + for node in self.nodes: + if node["type"] == resource_type: + total += node["slots"] - node["used"] + return total + + def get_available_card_ids(self, resource_type="gpu", address="auto", num=1): + """ + Allocate 'num' resource cards from a node and return a list of card indices. + + For the default case (address="auto"), traverse nodes in order: master node first, then worker nodes. + - If a node's available slots (slots - used) are >= num, allocate num consecutive indices (based on the current used value) + and update the node's used count, returning the allocated indices (0-indexed) as a list. + - If the available slots are insufficient at a particular node and address is "auto", continue searching through other nodes. + - If an explicit address is provided, check only that node; if it doesn't exist or lacks sufficient available slots, raise an error. + - If none of the nodes can satisfy the request, raise an error indicating insufficient resources. + """ + # Check the specified node if address is not "auto" + if address != "auto": + node_found = None + for node in self.nodes: + if node["address"] == address and node["type"] == resource_type: + node_found = node + break + if node_found is None: + raise ValueError(f"Node {address} does not exist or resource type mismatch") + free = node_found["slots"] - node_found["used"] + if free < num: + raise ValueError("Insufficient resources") + allocated_ids = list(range(node_found["used"], node_found["used"] + num)) + node_found["used"] += num + return allocated_ids + + # For address == "auto", traverse all nodes (master node first, then worker nodes) + for node in self.nodes: + if node["type"] == resource_type: + free = node["slots"] - node["used"] + if free >= num: + allocated_ids = list(range(node["used"], node["used"] + num)) + node["used"] += num + return allocated_ids + + # If no node satisfies the allocation request, raise an error. + resource_status = self.get_status() + raise ValueError( + f"Require number {num} of resource_type {resource_type} But there is insufficient resources: \n{resource_status}" + ) + + def get_status(self): + """ + Return the status of all nodes as a dictionary. + Each key in the returned dictionary is the node's address, and its value is a dictionary with: + - type: the resource type. + - slots: the total number of slots. + - used: the number of allocated slots. 
+          - available: the number of available slots (slots - used).
+        """
+        status = {}
+        for node in self.nodes:
+            status[node["address"]] = {
+                "type": node["type"],
+                "slots": node["slots"],
+                "used": node["used"],
+                "available": node["slots"] - node["used"],
+            }
+        return status

From f070e969896a71c0f90be165a67565ace0b37f5f Mon Sep 17 00:00:00 2001
From: chenzhuo
Date: Tue, 29 Apr 2025 19:31:50 +0800
Subject: [PATCH 60/62] update vllm submodule commit

---
 third_party/vllm | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/third_party/vllm b/third_party/vllm
index 139e70670..dc1b4a6f1 160000
--- a/third_party/vllm
+++ b/third_party/vllm
@@ -1 +1 @@
-Subproject commit 139e706706402859fd55381a39dd7182f0537977
+Subproject commit dc1b4a6f1300003ae27f033afbdff5e2683721ce

From b061d00fa85ce5a513cbf13d0d33be50586924b3 Mon Sep 17 00:00:00 2001
From: chenzhuo
Date: Tue, 29 Apr 2025 19:38:34 +0800
Subject: [PATCH 61/62] rename run_pd_disaggregation_router to
 run_disagg_xpyd_router

---
 flagscale/runner/runner_serve.py                              | 4 ++--
 ..._pd_disaggregation_router.py => run_disagg_xpyd_router.py} | 0
 2 files changed, 2 insertions(+), 2 deletions(-)
 rename flagscale/serve/{run_pd_disaggregation_router.py => run_disagg_xpyd_router.py} (100%)

diff --git a/flagscale/runner/runner_serve.py b/flagscale/runner/runner_serve.py
index 64a7998c9..8fce37cba 100644
--- a/flagscale/runner/runner_serve.py
+++ b/flagscale/runner/runner_serve.py
@@ -239,7 +239,7 @@ def _generate_run_script_serve(config, host, node_rank, cmd, background=True, wi
         f.write("pkill -f 'run_inference_engine'\n")
         f.write("pkill -f 'run_fs_serve_vllm'\n")
         f.write("pkill -f 'vllm serve'\n")
-        f.write("pkill -f 'run_pd_disaggregation_router'\n")
+        f.write("pkill -f 'run_disagg_xpyd_router'\n")
         f.write(f"mkdir -p {default_log_dir}\n")
         f.write(f"\n")
 
@@ -547,7 +547,7 @@ def _prepare(self):
         entrypoint = self.config.experiment.task.get("entrypoint", None)
         if self.inference_engine:
             if self.config.experiment.get("deploy", {}).get("prefill_decode_disaggregation", False):
-                self.user_script = "flagscale/serve/run_pd_disaggregation_router.py"
+                self.user_script = "flagscale/serve/run_disagg_xpyd_router.py"
             elif not self.use_fs_serve:
                 self.user_script = "flagscale/serve/run_inference_engine.py"
             else:
diff --git a/flagscale/serve/run_pd_disaggregation_router.py b/flagscale/serve/run_disagg_xpyd_router.py
similarity index 100%
rename from flagscale/serve/run_pd_disaggregation_router.py
rename to flagscale/serve/run_disagg_xpyd_router.py

From cd70188d08bebd6ae29049716c5d77b9876cb6f1 Mon Sep 17 00:00:00 2001
From: chenzhuo
Date: Tue, 29 Apr 2025 19:52:47 +0800
Subject: [PATCH 62/62] rename qwen2.5_7b PD config to
 config_qwen2.5_7b_disagg_xpyd.yaml

---
 ..._pd_disaggregation.yaml => config_qwen2.5_7b_disagg_xpyd.yaml} | 0
 1 file changed, 0 insertions(+), 0 deletions(-)
 rename examples/qwen/conf/{config_qwen2.5_7b_pd_disaggregation.yaml => config_qwen2.5_7b_disagg_xpyd.yaml} (100%)

diff --git a/examples/qwen/conf/config_qwen2.5_7b_pd_disaggregation.yaml b/examples/qwen/conf/config_qwen2.5_7b_disagg_xpyd.yaml
similarity index 100%
rename from examples/qwen/conf/config_qwen2.5_7b_pd_disaggregation.yaml
rename to examples/qwen/conf/config_qwen2.5_7b_disagg_xpyd.yaml
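
A minimal usage sketch of the ResourceManager relocated to flagscale/runner/utils.py
by PATCH 59. The import path reflects the post-series layout; the node addresses and
slot counts are illustrative placeholders, and the expected values in the comments
follow directly from the class shown above.

    from flagscale.runner.utils import ResourceManager

    nodes = [
        ["192.168.0.1", {"slots": 8, "type": "gpu"}],  # first entry is treated as the master node
        ["192.168.0.2", {"slots": 8}],                 # missing "type" defaults to "gpu" and logs a warning
    ]
    rm = ResourceManager(nodes)

    assert rm.get_whole_card_num() == 16      # total gpu slots across both nodes
    assert rm.get_available_card_num() == 16  # nothing allocated yet

    # address="auto" is first-fit in node order, so the master node serves this request.
    prefill_ids = rm.get_available_card_ids(resource_type="gpu", num=1)   # -> [0]

    # An explicit address pins the allocation to that node (or raises ValueError).
    decode_ids = rm.get_available_card_ids(address="192.168.0.2", num=2)  # -> [0, 1]

    print(rm.get_status())
    # {'192.168.0.1': {'type': 'gpu', 'slots': 8, 'used': 1, 'available': 7},
    #  '192.168.0.2': {'type': 'gpu', 'slots': 8, 'used': 2, 'available': 6}}

    # Requesting more cards than any single node has free raises ValueError
    # with the current status snapshot attached.
    try:
        rm.get_available_card_ids(num=16)
    except ValueError as err:
        print(err)

Note that an allocation never spans nodes: get_available_card_ids() satisfies the
whole request from a single node or raises.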
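
With PATCH 61 and PATCH 62 applied together, the router entrypoint
(flagscale/serve/run_disagg_xpyd_router.py), the pkill filter in the generated run
script, and the example config (examples/qwen/conf/config_qwen2.5_7b_disagg_xpyd.yaml)
all share the disagg_xpyd naming. Assuming FlagScale's usual hydra-based run.py
entrypoint (an assumption, not shown in this series), a launch would then look like
python run.py --config-path examples/qwen/conf --config-name config_qwen2.5_7b_disagg_xpyd
and resolve both the renamed config and the renamed router script without further edits.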