
Commit 89d96cc

Merge branch 'main' into padded-spec
2 parents 9a9ec7f + b4233a2, commit 89d96cc

File tree: 81 files changed (+2924, -984 lines)


.github/workflows/multi_node_test.yaml

Lines changed: 9 additions & 0 deletions

```diff
@@ -102,6 +102,15 @@ jobs:
           wait $LOG_PID || true
           kill $MONITOR_PID || true
 
+      - name: Generate summary
+        if: always()
+        run: |
+          if [ -f "/root/.cache/test_summary.md" ]; then
+            cat /root/.cache/test_summary.md >> "$GITHUB_STEP_SUMMARY"
+          else
+            echo "No summary file found." >> "$GITHUB_STEP_SUMMARY"
+          fi
+
       - name: Post process
         if: always()
         run: |
```
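The new step only consumes `/root/.cache/test_summary.md`; an earlier test step is expected to write that file. A minimal sketch of a producer, assuming nothing about the real test harness (the table contents below are hypothetical, not part of this commit):

```shell
# Hypothetical producer: write a Markdown summary for the new
# "Generate summary" step to append to $GITHUB_STEP_SUMMARY,
# which GitHub renders on the workflow run page.
mkdir -p /root/.cache
{
  echo "## Multi-node test results"
  echo ""
  echo "| Case | Status |"
  echo "| --- | --- |"
  echo "| disaggregated_prefill | pass |"
} > /root/.cache/test_summary.md
```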

docs/source/tutorials/multi_node_pd_disaggregation_mooncake.md

Lines changed: 5 additions & 5 deletions

````diff
@@ -66,16 +66,16 @@ Install the relevant dependencies. The installation of Go is not required.
 
 ```shell
 cd Mooncake
-bash dependencies.sh
+bash dependencies.sh -y
 ```
 
 Install mpi
 
 ```shell
-apt purge mpich libmpich-dev
-apt purge openmpi-bin
-apt purge openmpi-bin libopenmpi-dev
-apt install mpich libmpich-dev
+apt purge mpich libmpich-dev -y
+apt purge openmpi-bin -y
+apt purge openmpi-bin libopenmpi-dev -y
+apt install mpich libmpich-dev -y
 export CPATH=/usr/lib/aarch64-linux-gnu/mpich/include/:$CPATH
 export CPATH=/usr/lib/aarch64-linux-gnu/openmpi/lib:$CPATH
 ```
````
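After replacing Open MPI with MPICH it is worth confirming which implementation the toolchain now resolves to. A quick check, assuming the Debian/Ubuntu MPICH packages installed above (which ship `mpichversion` and the `mpicc` wrapper):

```shell
# Should print an MPICH version string, not Open MPI.
mpichversion | head -n 1
# MPICH's wrapper supports -show; it prints the underlying compile
# line and should reference the mpich include paths exported via CPATH.
mpicc -show
```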

examples/disaggregated_prefill_v1/README.md

Lines changed: 1 addition & 1 deletion

````diff
@@ -205,7 +205,7 @@ vllm serve /models/deepseek_r1_w8a8 \
 Run proxy server on the first node:
 ```shell
 cd /vllm-workspace/vllm-ascend/examples/disaggregated_prefill_v1
-python toy_proxy_server.py --host 172.19.32.175 --port 1025 --prefiller-hosts 172.19.241.49 --prefiller-port 20002 --decoder-hosts 172.19.123.51 --decoder-ports 20002
+python load_balance_proxy_server_example.py --host 172.19.32.175 --port 1025 --prefiller-hosts 172.19.241.49 --prefiller-port 20002 --decoder-hosts 172.19.123.51 --decoder-ports 20002
 ```
 
 Verification
````
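Once the load-balancing proxy is up, a smoke test can be sent through it. A sketch, assuming the proxy forwards the OpenAI-compatible `/v1/completions` route and reusing the model path from the serve command above:

```shell
# Hedged smoke test through the proxy; adjust the model path as needed.
curl -s http://172.19.32.175:1025/v1/completions \
  -H "Content-Type: application/json" \
  -d '{"model": "/models/deepseek_r1_w8a8", "prompt": "Hello", "max_tokens": 16}'
```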

examples/disaggregated_prefill_v1/gen_ranktable.py

Lines changed: 6 additions & 1 deletion

```diff
@@ -21,6 +21,10 @@
                     type=str,
                     required=False,
                     help="local device ids")
+parser.add_argument("--ranktable-path",
+                    type=str,
+                    default="./ranktable.json",
+                    help="output rank table path")
 args = parser.parse_args()
 local_host = args.local_host
 prefill_device_cnt = args.prefill_device_cnt
@@ -130,7 +134,8 @@ def get_cmd_stdout(cmd):
     }
 
 if local_rank == '0':
-    with open("ranktable.json", "w") as f:
+    os.makedirs(os.path.dirname(args.ranktable_path), exist_ok=True)
+    with open(args.ranktable_path, "w") as f:
         json.dump(ranktable, f, indent=4)
 
 print("gen ranktable.json done")
```

examples/disaggregated_prefill_v1/load_balance_proxy_layerwise_server_example.py

Lines changed: 167 additions & 63 deletions

```diff
@@ -84,17 +84,18 @@
 #
 # For more details, see the code and comments in this file.
 
-
 import argparse
 import asyncio
 import functools
 import heapq
+import json
 import os
 import sys
-import uuid
 import threading
+import uuid
 from contextlib import asynccontextmanager
-from typing import List
+from dataclasses import dataclass
+from typing import Any, List
 
 import httpx
 from fastapi import FastAPI, Request
@@ -106,6 +107,7 @@
 # Add uvloop for faster event loop if available
 try:
     import uvloop
+
     asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
 except ImportError:
     pass
@@ -324,7 +326,7 @@ async def listen_for_disconnect(request: Request) -> None:
 
 
 def with_cancellation(handler_func):
-
+
     @functools.wraps(handler_func)
     async def wrapper(*args, **kwargs):
         request = kwargs["request"]
@@ -337,9 +339,9 @@ async def wrapper(*args, **kwargs):
         if handler_task in done:
             return handler_task.result()
         return None
-
+
     return wrapper
-
+
 
 app = FastAPI(lifespan=lifespan)
 
@@ -362,7 +364,8 @@ async def send_request_to_service(client: httpx.AsyncClient,
         "remote_host": None,
         "remote_port": None,
         "aborted_request": list(aborted_requests),
-        "metaserver": f"http://{global_args.host}:{global_args.port}/v1/metaserver"
+        "metaserver":
+        f"http://{global_args.host}:{global_args.port}/v1/metaserver"
     }
     req_data["stream"] = False
    req_data["max_tokens"] = 1
@@ -455,72 +458,174 @@ def get_api_request_id(api, req_id):
     return "chatcmpl-" + req_id
 
 
+async def _handle_select_instance(api: str, req_data: Any,
+                                  request_length: int):
+    prefiller_score = proxy_state.calculate_prefill_scores(request_length)
+    logger.debug(
+        f"Request length: {request_length}, Prefiller score: {prefiller_score}"
+    )
+    request_id = await proxy_state.next_req_id()
+    # Select prefiller
+    prefiller_idx = proxy_state.select_prefiller(prefiller_score)
+    prefiller = proxy_state.prefillers[prefiller_idx]
+    result_future = asyncio.Future()  # type: ignore
+    request_id_api = get_api_request_id(api, request_id)
+    proxy_state.req_id_future[request_id_api] = result_future
+    # Send request to prefiller
+    asyncio.get_running_loop().create_task(
+        send_request_to_service(prefiller.client,
+                                prefiller_idx,
+                                api,
+                                req_data,
+                                request_id,
+                                max_retries=global_args.max_retries,
+                                base_delay=global_args.retry_delay))
+    proxy_state.release_prefiller(prefiller_idx, prefiller_score)
+
+    response = await result_future
+    del proxy_state.req_id_future[request_id_api]
+    req_data["kv_transfer_params"] = response
+
+    # Select decoder
+    decoder_score = proxy_state.calculate_decode_scores(request_length)
+    logger.debug("Decoder score: %f", decoder_score)
+    # Use the prefiller's kv_transfer_params to select decoder
+    decoder_idx = proxy_state.select_decoder(decoder_score)
+    decoder = proxy_state.decoders[decoder_idx]
+    logger.debug("Using %s %s", prefiller.url, decoder.url)
+    return InstanceInfo(request_id=request_id,
+                        prefiller_idx=prefiller_idx,
+                        prefiller_score=prefiller_score,
+                        prefiller=prefiller,
+                        decoder=decoder,
+                        decoder_idx=decoder_idx,
+                        decoder_score=decoder_score)
+
+
+@dataclass
+class InstanceInfo:
+    request_id: str
+    prefiller_idx: int
+    prefiller_score: float
+    prefiller: ServerState
+    decoder_idx: int
+    decoder_score: float
+    decoder: ServerState
+
+
 async def _handle_completions(api: str, request: Request):
     try:
         req_data = await request.json()
         req_body = await request.body()
         request_length = len(req_body)
-        prefiller_score = proxy_state.calculate_prefill_scores(request_length)
-        logger.debug(
-            f"Request length: {request_length}, Prefiller score: {prefiller_score}"
-        )
-        request_id = await proxy_state.next_req_id()
-        # Select prefiller
-        prefiller_idx = proxy_state.select_prefiller(prefiller_score)
-        prefiller = proxy_state.prefillers[prefiller_idx]
-        result_future = asyncio.Future()  # type: ignore
-        request_id_api = get_api_request_id(api, request_id)
-        proxy_state.req_id_future[request_id_api] = result_future
-        # Send request to prefiller
-        asyncio.get_running_loop().create_task(send_request_to_service(
-            prefiller.client,
-            prefiller_idx,
-            api,
-            req_data,
-            request_id,
-            max_retries=global_args.max_retries,
-            base_delay=global_args.retry_delay))
-        proxy_state.release_prefiller(prefiller_idx, prefiller_score)
-
-        response = await result_future
-        del proxy_state.req_id_future[request_id_api]
-        req_data["kv_transfer_params"] = response
-
-        # Select decoder
-        decoder_score = proxy_state.calculate_decode_scores(request_length)
-        logger.debug("Decoder score: %f", decoder_score)
-        # Use the prefiller's kv_transfer_params to select decoder
-        decoder_idx = proxy_state.select_decoder(decoder_score)
-        decoder = proxy_state.decoders[decoder_idx]
-        logger.debug("Using %s %s", prefiller.url, decoder.url)
-        # Stream response from decoder
-        released_kv = False
+        instance_info = await _handle_select_instance(api, req_data,
+                                                      request_length)
+        stream_flag = bool(req_data.get("stream", False))
+        chat_flag = "messages" in req_data
+
+        if "prompt" in req_data:
+            origin_prompt = req_data["prompt"]
+        elif chat_flag:
+            messages = req_data["messages"]
+            origin_prompt = messages[0].get("content", "")
+        else:
+            origin_prompt = ""
+        # refer to vLLM sampling_params: max_token default value
+        origin_max_tokens = req_data.get("max_tokens", 16)
+
         async def generate_stream():
-            nonlocal released_kv
+            nonlocal instance_info
+            generated_token = ""
+            released_kv = False
+            retry_count = 0
+            retry = True
+            completion_tokens = 0
             # Only one await per chunk, minimal logic in loop
             try:
-                async for chunk in stream_service_response_with_retry(
-                        decoder.client,
-                        api,
-                        req_data,
-                        request_id=request_id,
-                        max_retries=global_args.max_retries,
-                        base_delay=global_args.retry_delay):
-                    if not released_kv and chunk:
-                        proxy_state.release_prefiller_kv(
-                            prefiller_idx, prefiller_score)
-                        released_kv = True
-                    yield chunk
+                while retry:
+                    retry = False
+                    async for chunk in stream_service_response_with_retry(
+                            instance_info.decoder.client,
+                            api,
+                            req_data,
+                            request_id=instance_info.request_id,
+                            max_retries=global_args.max_retries,
+                            base_delay=global_args.retry_delay):
+                        if not released_kv and chunk:
+                            proxy_state.release_prefiller_kv(
+                                instance_info.prefiller_idx,
+                                instance_info.prefiller_score)
+                            released_kv = True
+                        chunk_str = chunk.decode("utf-8").strip()
+                        if not chunk_str:
+                            continue
+                        if chunk_str.startswith("data: "):
+                            chunk_str = chunk_str[len("data: "):]
+                        try:
+                            chunk_json = json.loads(chunk_str)
+                        except json.JSONDecodeError:
+                            # if chunk is [done], skip it.
+                            logger.warning(
+                                f"Skipping chunk: {chunk_str}")
+                            yield chunk
+                            continue
+                        choices = chunk_json.get("choices", [])
+                        if not choices:
+                            yield chunk
+                            continue
+
+                        choice = choices[0]
+                        delta = choice.get("delta") or {}
+                        message = choice.get("message") or {}
+                        content = (
+                            delta.get("content")
+                            or message.get("content")
+                            or choice.get("text")
+                            or ""
+                        )
+                        generated_token += content
+
+                        stop_reason = choice.get(
+                            "stop_reason")
+                        usage = chunk_json.get("usage", {})
+                        completion_tokens = (completion_tokens + 1) if stream_flag else \
+                            (completion_tokens + usage.get("completion_tokens"))
+                        if stop_reason == "recomputed":
+                            retry = True
+                            retry_count += 1
+                            if chat_flag:
+                                messages[0][
+                                    "content"] = origin_prompt + generated_token
+                            else:
+                                req_data[
+                                    "prompt"] = origin_prompt + generated_token
+                            req_data[
+                                "max_tokens"] = origin_max_tokens - completion_tokens + retry_count
+                            tmp_request_length = len(
+                                json.dumps(req_data).encode("utf-8"))
+                            instance_info = await _handle_select_instance(
+                                api, req_data, tmp_request_length)
+                            break
+                        if retry_count > 0 and not stream_flag:
+                            if chat_flag:
+                                choices[0]["message"][
+                                    "content"] = generated_token
+                            else:
+                                choices[0]["text"] = generated_token
+                            chunk = json.dumps(chunk_json).encode("utf-8")
+                        yield chunk
             except Exception as e:
                 logger.error(
-                    f"Error during streaming from decoder {decoder.url}: {str(e)} the aborted request {request_id} will be routing to the target prefiller when new request is ready to dispatch to it"
+                    f"Error during streaming from decoder {instance_info.decoder.url}: {str(e)} the aborted request {instance_info.request_id} will be routing to the target prefiller when new request is ready to dispatch to it"
                 )
-                proxy_state.abort_prefiller_request(prefiller_idx, request_id)
-                proxy_state.release_prefiller_kv(prefiller_idx,
-                                                 prefiller_score)
+                proxy_state.abort_prefiller_request(
+                    instance_info.prefiller_idx, instance_info.request_id)
+                proxy_state.release_prefiller_kv(instance_info.prefiller_idx,
+                                                 instance_info.prefiller_score)
 
         # After streaming done, release tokens
-        proxy_state.release_decoder(decoder_idx, decoder_score)
+        proxy_state.release_decoder(instance_info.decoder_idx,
+                                    instance_info.decoder_score)
 
         return StreamingResponse(generate_stream(),
                                  media_type="application/json")
@@ -564,13 +669,12 @@ async def metaserver(request: Request):
         result_future = proxy_state.req_id_future[request_id]
         result_future.set_result(req_data)
     except Exception as e:
-        logger.error(
-            f"Post metaserver failed with: {str(e)}"
-        )
+        logger.error(f"Post metaserver failed with: {str(e)}")
 
 
 if __name__ == '__main__':
     global global_args
     global_args = parse_args()
     import uvicorn
+
     uvicorn.run(app, host=global_args.host, port=global_args.port)
```
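The reworked `generate_stream` parses each streamed chunk, accumulates the generated text, and when a choice reports `stop_reason == "recomputed"` it re-runs instance selection with the prompt extended by the text produced so far and `max_tokens` shrunk to `origin_max_tokens - completion_tokens + retry_count` (e.g. with `max_tokens=64`, 20 tokens already emitted, and one retry, the continuation asks for 64 - 20 + 1 = 45 tokens). From the client's side the retry is transparent. A hedged streaming smoke test, reusing the host, port, and model from the README example and assuming the proxy mirrors vLLM's `/v1/chat/completions` route:

```shell
# Stream through the proxy; on a "recomputed" stop_reason the proxy
# re-selects a prefiller/decoder pair and continues from the
# accumulated text without the client noticing.
curl -sN http://172.19.32.175:1025/v1/chat/completions \
  -H "Content-Type: application/json" \
  -d '{"model": "/models/deepseek_r1_w8a8",
       "messages": [{"role": "user", "content": "Hello"}],
       "stream": true, "max_tokens": 64}'
```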
