85 changes: 85 additions & 0 deletions examples/disaggregated-prefill-v1/disagg_prefill_proxy_server.py
@@ -0,0 +1,85 @@
# SPDX-License-Identifier: Apache-2.0

import json
import os

import aiohttp
from quart import Quart, make_response, request

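# Generous client timeout (6 hours) so long-running streamed responses are not cut off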
AIOHTTP_TIMEOUT = aiohttp.ClientTimeout(total=6 * 60 * 60)

app = Quart(__name__)

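# Backend vLLM instances; these ports must match the ones used when launching
# the prefill and decode servers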
PREFILL_ENDPOINT = "localhost:8100"
DECODE_ENDPOINT = "localhost:8200"


async def forward_request(url, data, headers: dict):
    async with aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session:
        headers.update({
            "Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}",
        })

        async with session.post(url=url, json=data,
                                headers=headers) as response:
            if response.status == 200:
                async for chunk_bytes in response.content.iter_chunked(1024):
                    yield chunk_bytes


@app.route("/v1/completions", methods=["POST"])
async def handle_request():
    try:
        original_request_data = await request.get_json()
        print(f"{request.headers.get('X-Request-ID')=}")

        prefill_request = original_request_data.copy()
        # Set max_tokens to 1 so the prefill instance only runs the prefill phase
        prefill_request["max_tokens"] = 1

        # Finish prefill
        async for prefill_result in forward_request(
                f"http://{PREFILL_ENDPOINT}/v1/completions",
                prefill_request,
                headers={
                    "X-Request-ID": request.headers.get("X-Request-ID"),
                },
        ):
            # Print the prefill result
            print("===== Prefill result =====")
            print(prefill_result.decode("utf-8"))
            print("==========================")
            response = json.loads(prefill_result.decode("utf-8"))
            continue

        # Append the token generated by prefill to each prompt for the decode request
        decode_request = original_request_data.copy()
        for idx, choices in enumerate(response.get("choices")):
            decode_request["prompt"][idx] += choices.get("text")

        # Stream the decoding result back to the client
        generator = forward_request(
            f"http://{DECODE_ENDPOINT}/v1/completions",
            decode_request,
            headers={
                "X-Request-ID": request.headers.get("X-Request-ID"),
            },
        )
        response = await make_response(generator)
        response.timeout = None

        return response

    except Exception as e:
        import sys
        import traceback

        exc_info = sys.exc_info()
        print("Error occurred in disagg prefill proxy server")
        print(e)
        print("".join(traceback.format_exception(*exc_info)))


if __name__ == "__main__":
app.run(port=8000)
@@ -0,0 +1,110 @@
#!/bin/bash
# This file demonstrates example usage of disaggregated prefill. We launch
# two vLLM instances (one for prefill and one for decode), and then transfer
# the KV cache between them.
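#
# Request flow (ports as configured below and in disagg_prefill_proxy_server.py):
#   client -> proxy :8000 -> prefill instance :8100 (max_tokens forced to 1)
#                         -> decode instance  :8200 (streams the final completion)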

set -xe

current_dir=$(dirname "$0")

# vLLM Environment configuration
export VLLM_USE_V1=1

# vLLM-Ascend Environment configuration
export GLOBAL_RANKTABLE="${current_dir}/global_ranktable.json"
# The following environment variables are required for LLMDataDist.
export PROMPT_DEVICE_ID=0,1,2,3
export DECODE_DEVICE_ID=4,5,6,7
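# Tensor parallel size is derived from the device list: number of commas + 1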
export TENSOR_PARALLEL_SIZE=$(($(echo $PROMPT_DEVICE_ID | grep -o ',' | wc -l) + 1))

# Model Configuration
export MODEL_NAME="deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B"

# Generate the global rank table
if [ ! -f "${GLOBAL_RANKTABLE}" ]; then
    echo "Generating global rank table..."
    # TODO(jianzs): Impl a tool to generate the global rank table automatically
else
    echo "Global rank table already exists."
fi

echo "🚧🚧 Warning: The usage of disaggregated prefill is experimental and subject to change 🚧🚧"
sleep 1

# Trap the SIGINT signal (triggered by Ctrl+C)
trap 'cleanup' INT

# Cleanup function
cleanup() {
    echo "Caught Ctrl+C, cleaning up..."
    # Kill the vLLM server and proxy processes
    pgrep python | xargs kill -9
    pkill -f python
    echo "Cleanup complete. Exiting."
    exit 0
}

# Install quart first -- required for the disagg prefill proxy server
if python3 -c "import quart" &>/dev/null; then
    echo "Quart is already installed."
else
    echo "Quart is not installed. Installing..."
    python3 -m pip install quart
fi

# Wait for a vLLM server to start on the given port
wait_for_server() {
    local port=$1
    timeout 1200 bash -c "
        until curl -s localhost:${port}/v1/completions > /dev/null; do
            sleep 1
        done" && return 0 || return 1
}

ASCEND_RT_VISIBLE_DEVICES=${PROMPT_DEVICE_ID} vllm serve ${MODEL_NAME} \
    --port 8100 \
    --max-model-len 100 \
    --gpu-memory-utilization 0.9 \
    --trust-remote-code \
    --enforce-eager \
    --no-enable-prefix-caching \
    --tensor-parallel-size ${TENSOR_PARALLEL_SIZE} \
    --kv-transfer-config \
    '{
        "kv_connector": "AscendHcclConnectorV1",
        "kv_buffer_device": "npu",
        "kv_role": "kv_producer",
        "kv_rank": 0,
        "kv_parallel_size": 2,
        "kv_connector_extra_config": {
            "local_server_id": "server-0"
        }
    }' &

ASCEND_RT_VISIBLE_DEVICES=${DECODE_DEVICE_ID} vllm serve ${MODEL_NAME} \
    --port 8200 \
    --max-model-len 100 \
    --gpu-memory-utilization 0.9 \
    --trust-remote-code \
    --enforce-eager \
    --no-enable-prefix-caching \
    --tensor-parallel-size ${TENSOR_PARALLEL_SIZE} \
    --kv-transfer-config \
    '{
        "kv_connector": "AscendHcclConnectorV1",
        "kv_buffer_device": "npu",
        "kv_role": "kv_consumer",
        "kv_rank": 1,
        "kv_parallel_size": 2,
        "kv_connector_extra_config": {
            "local_server_id": "server-1"
        }
    }' &

# wait until prefill and decode instances are ready
wait_for_server 8100
wait_for_server 8200

echo "🚧🚧 Prefill and decode servers started 🚧🚧"

python3 disagg_prefill_proxy_server.py
23 changes: 23 additions & 0 deletions examples/disaggregated-prefill-v1/send_request.sh
@@ -0,0 +1,23 @@
#!/bin/bash

# Make sure the model is the same as the one used by the servers
MODEL_NAME="deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B"
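# Random request ID; the proxy forwards it to both backends via the X-Request-ID header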
REQUEST_ID=request$RANDOM

curl http://localhost:8000/v1/completions \
    -H "Content-Type: application/json" \
    -H "X-Request-ID: ${REQUEST_ID}" \
    -d '{
        "ignore_eos": false,
        "stream": false,
        "stop": "None",
        "temperature": 0.5,
        "top_k": -1,
        "top_p": 1,
        "model": "'${MODEL_NAME}'",
        "prompt": [
            "In 2020, who won the world series?",
            "In 2019, who won the world series?"
        ],
        "max_tokens": 40
    }'
4 changes: 2 additions & 2 deletions vllm_ascend/attention/mla_v1.py
@@ -244,8 +244,8 @@ def build(self,
         if self._num_decodes > 0:
             decode_metadata = AscendMLADecodeMetadata(
                 input_positions=input_positions[:self._num_decode_tokens],
-                block_table=block_table[:self._num_decode_tokens, ...],
-                seq_lens=seq_lens[:self._num_decode_tokens])
+                block_table=block_table[:self._num_decodes, ...],
+                seq_lens=seq_lens[:self._num_decodes])

         return self.metadata_cls(  # type: ignore
             num_actual_tokens=num_actual_tokens,
5 changes: 5 additions & 0 deletions vllm_ascend/distributed/__init__.py
@@ -4,3 +4,8 @@
 KVConnectorFactory.register_connector(
     "AscendHcclConnector", "vllm_ascend.distributed.llmdatadist_connector",
     "LLMDataDistConnector")
+
+KVConnectorFactory.register_connector(
+    "AscendHcclConnectorV1",
+    "vllm_ascend.distributed.llmdatadist_connector_v1",
+    "LLMDataDistConnectorV1")