[Core] Implement disagg prefill by StatelessProcessGroup #10502

Merged · 35 commits · Dec 2, 2024

Commits
d72d171
move commit from PR 8498 to here in order to fix DCO
KuntaiDu Nov 20, 2024
1eadc94
Make ruff happy.
KuntaiDu Nov 20, 2024
a36c12c
fix wrong import in tests
KuntaiDu Nov 20, 2024
6768108
adjust docstring
KuntaiDu Nov 20, 2024
83e3589
make yapf happy
KuntaiDu Nov 20, 2024
263677b
make format checker happy
KuntaiDu Nov 20, 2024
47a3f79
add an empty __init__.py file to make sure it is recognized as python…
KuntaiDu Nov 21, 2024
e84e14a
bug fix: handle the case where kv_transfer_config is None
KuntaiDu Nov 21, 2024
0243d71
further fix the case when config is None
KuntaiDu Nov 21, 2024
fdd605a
make yapf happy
KuntaiDu Nov 21, 2024
5c27cb7
Merge branch 'main' into kuntai-disagg-fix-DCO
KuntaiDu Nov 22, 2024
509c60e
resolve merge conflict
KuntaiDu Nov 24, 2024
f6ca0bb
make format checker happy
KuntaiDu Nov 24, 2024
2fcb099
adjust docstring
KuntaiDu Nov 28, 2024
68ef930
align default value for kv-buffer-device
KuntaiDu Nov 28, 2024
c0d8d22
Adjust implementation for CPU send recv in statelessprocessgroup
KuntaiDu Nov 28, 2024
b9315ec
send VllmConfig instead of KVTransferConfig to connector so that it c…
KuntaiDu Nov 28, 2024
dda64b9
add -X POST flag to curl -- thanks @coolkp
KuntaiDu Nov 28, 2024
25aab25
lowering the max model len, to make sure the example is runnable on l…
KuntaiDu Nov 28, 2024
87eaea2
make format checker happy
KuntaiDu Nov 28, 2024
6b0c370
Merge kv transfer config into a single CLI, and people can add value …
KuntaiDu Nov 29, 2024
11bd115
clean up diff
KuntaiDu Nov 29, 2024
bef8171
Syntax adjustments for disagg prefill example.
KuntaiDu Nov 29, 2024
f826c44
Adjust test---we now pass '--kv-transfer-config' instead of multiple …
KuntaiDu Nov 29, 2024
2cd7ee5
Remove send and recv from StatelessProcessGroup, add comments to remi…
KuntaiDu Dec 1, 2024
6b7901e
Make format checker happy
KuntaiDu Dec 1, 2024
0cc51a0
bug fix on PyNcclPipe: it needs to support CPU as a device in order t…
KuntaiDu Dec 1, 2024
6af15f0
Add a warning to notify people that the usage of disaggregated prefil…
KuntaiDu Dec 1, 2024
33973e9
Reduce the # of GPU to 2, in order to save CI cost.
KuntaiDu Dec 1, 2024
a14180f
Refactor --- create three abstractions: kv_pipe, kv_lookup_buffer and…
KuntaiDu Dec 1, 2024
4fe6fcc
Add KVLookupBufferBase abstraction
KuntaiDu Dec 1, 2024
b721b5c
Add README and adjust docstrings
KuntaiDu Dec 1, 2024
8d6916a
Merge disaggregated prefill test into existing 2GPU test.
KuntaiDu Dec 1, 2024
a05676d
Rename config to vllm_config for consistency.
KuntaiDu Dec 1, 2024
c3aa7bb
fix typo in README
KuntaiDu Dec 1, 2024
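
The refactor commits above (a14180f, 4fe6fcc) split the KV transfer path into named abstractions: kv_pipe for raw tensor transport and kv_lookup_buffer for holding prefilled KV caches (the commit message truncates the third name). A minimal sketch of what such interfaces could look like; the method names and signatures below are illustrative assumptions, not the PR's actual definitions:

```python
# Illustrative sketch only: the abstraction names come from the commit
# messages above; these method names and signatures are assumptions.
from abc import ABC, abstractmethod
from typing import List, Optional

import torch


class KVPipeBase(ABC):
    """Point-to-point tensor transport, e.g. a pipe built on PyNccl."""

    @abstractmethod
    def send_tensor(self, tensor: Optional[torch.Tensor]) -> None:
        ...

    @abstractmethod
    def recv_tensor(self) -> Optional[torch.Tensor]:
        ...


class KVLookupBufferBase(ABC):
    """Buffers prefilled KV caches until the decode instance drains them."""

    @abstractmethod
    def insert(self, input_tokens: torch.Tensor, roi: torch.Tensor,
               key: torch.Tensor, value: torch.Tensor,
               hidden: torch.Tensor) -> None:
        """Producer side: stash the KV cache of a prefilled request."""

    @abstractmethod
    def drop_select(self, input_tokens: torch.Tensor,
                    roi: torch.Tensor) -> List[Optional[torch.Tensor]]:
        """Consumer side: retrieve and remove the KV matching these tokens."""
```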
4 changes: 4 additions & 0 deletions .buildkite/test-pipeline.yaml
@@ -430,6 +430,9 @@ steps:
  - vllm/model_executor/models/
  - tests/distributed/
  - vllm/compilation
  - vllm/worker/worker_base.py
  - vllm/worker/worker.py
  - vllm/worker/model_runner.py
  commands:
  - pytest -v -s ./compile/test_basic_correctness.py
  - pytest -v -s ./compile/test_wrapper.py
@@ -443,6 +446,7 @@ steps:
  - pip install -e ./plugins/vllm_add_dummy_model
  - pytest -v -s distributed/test_distributed_oot.py
  - CUDA_VISIBLE_DEVICES=0,1 pytest -v -s test_sharded_state_loader.py
  - CUDA_VISIBLE_DEVICES=0,1 pytest -v -s kv_transfer/disagg_test.py

- label: Multi-step Tests (4 GPUs) # 36min
  working_dir: "/vllm-workspace/tests"
144 changes: 144 additions & 0 deletions benchmarks/disagg_benchmarks/disagg_overhead_benchmark.sh
@@ -0,0 +1,144 @@
#!/bin/bash

# Benchmark the overhead of disaggregated prefill.
# Methodology:
# - Send all requests to the prefill vLLM instance, which buffers the KV cache.
# - Then send all requests to the decode instance.
# - The TTFT of the decode instance is the overhead.

set -ex

kill_gpu_processes() {
    # kill all processes on GPU.
    pkill -f pt_main_thread
    sleep 10

    # remove vllm config file
    rm -rf ~/.config/vllm

    # Print the GPU memory usage
    # so that we know if all GPU processes are killed.
    gpu_memory_usage=$(nvidia-smi --query-gpu=memory.used --format=csv,noheader,nounits -i 0)
    # The memory usage should be 0 MB.
    echo "GPU 0 Memory Usage: $gpu_memory_usage MB"
}

wait_for_server() {
    # wait for the vLLM server on the given port to start;
    # return 1 if it is not up within the 1200-second timeout
    local port=$1
    timeout 1200 bash -c "
        until curl -s localhost:${port}/v1/completions > /dev/null; do
            sleep 1
        done" && return 0 || return 1
}


benchmark() {

    export VLLM_LOGGING_LEVEL=DEBUG
    export VLLM_HOST_IP=$(hostname -I | awk '{print $1}')

    # launch the prefill (kv_producer) and decode (kv_consumer) instances

    results_folder="./results"
    model="meta-llama/Meta-Llama-3.1-8B-Instruct"
    dataset_name="sonnet"
    dataset_path="../sonnet_4x.txt"
    num_prompts=10
    qps=$1
    prefix_len=50
    input_len=2048
    output_len=$2

    CUDA_VISIBLE_DEVICES=0 python3 \
        -m vllm.entrypoints.openai.api_server \
        --model meta-llama/Meta-Llama-3.1-8B-Instruct \
        --port 8100 \
        --max-model-len 10000 \
        --gpu-memory-utilization 0.6 \
        --kv-transfer-config \
        '{"kv_connector":"PyNcclConnector","kv_role":"kv_producer","kv_rank":0,"kv_parallel_size":2,"kv_buffer_size":5e9}' &

    CUDA_VISIBLE_DEVICES=1 python3 \
        -m vllm.entrypoints.openai.api_server \
        --model meta-llama/Meta-Llama-3.1-8B-Instruct \
        --port 8200 \
        --max-model-len 10000 \
        --gpu-memory-utilization 0.6 \
        --kv-transfer-config \
        '{"kv_connector":"PyNcclConnector","kv_role":"kv_consumer","kv_rank":1,"kv_parallel_size":2,"kv_buffer_size":5e9}' &

    wait_for_server 8100
    wait_for_server 8200

    # let the prefill instance finish prefill
    python3 ../benchmark_serving.py \
        --backend vllm \
        --model $model \
        --dataset-name $dataset_name \
        --dataset-path $dataset_path \
        --sonnet-input-len $input_len \
        --sonnet-output-len "$output_len" \
        --sonnet-prefix-len $prefix_len \
        --num-prompts $num_prompts \
        --port 8100 \
        --save-result \
        --result-dir $results_folder \
        --result-filename disagg_prefill_2xtp4.json \
        --request-rate "inf"

    # Send the same requests to the decode instance.
    # The TTFT of this command is the overhead of the disagg prefill implementation.
    python3 ../benchmark_serving.py \
        --backend vllm \
        --model $model \
        --dataset-name $dataset_name \
        --dataset-path $dataset_path \
        --sonnet-input-len $input_len \
        --sonnet-output-len "$output_len" \
        --sonnet-prefix-len $prefix_len \
        --num-prompts $num_prompts \
        --port 8200 \
        --save-result \
        --result-dir $results_folder \
        --result-filename disagg_prefill_2xtp4.json \
        --request-rate "$qps"

    kill_gpu_processes

}


main() {

    (which wget && which curl) || (apt-get update && apt-get install -y wget curl)
    (which jq) || (apt-get -y install jq)
    (which socat) || (apt-get -y install socat)

    pip install quart httpx

    cd "$(dirname "$0")"

    cd ..
    # create sonnet_4x.txt by repeating sonnet.txt four times
    echo "" > sonnet_4x.txt
    for _ in {1..4}
    do
        cat sonnet.txt >> sonnet_4x.txt
    done
    cd disagg_benchmarks

    rm -rf results
    mkdir results

    default_qps=1
    default_output_len=1
    benchmark $default_qps $default_output_len

}


main "$@"
164 changes: 164 additions & 0 deletions benchmarks/disagg_benchmarks/disagg_performance_benchmark.sh
@@ -0,0 +1,164 @@
#!/bin/bash

# Requirement: 8x H100 GPUs.


# Model: neuralmagic/Meta-Llama-3-70B-Instruct-FP8-KV
# Query: 2048 input tokens, 11 output tokens, QPS 4, 500 requests
# Resource: 8x H100
# Approaches:
# 1. Chunked prefill: 1 vLLM instance with tp=8
# 2. Chunked prefill: 2 vLLM instances with tp=4, equivalent to 1 tp=4 instance with QPS 4
# 3. Disaggregated prefill: 1 prefilling instance and 1 decoding instance
#    Prefilling instance: max_output_token=1
#    Decoding instance: force the input tokens to be the same across requests to bypass prefilling

set -ex

kill_gpu_processes() {
    # kill all processes on GPU.
    pgrep pt_main_thread | xargs -r kill -9
    pgrep python3 | xargs -r kill -9
    for port in 8000 8100 8200; do lsof -t -i:$port | xargs -r kill -9; done
    sleep 1
}

wait_for_server() {
    # wait for the vLLM server on the given port to start;
    # return 1 if it is not up within the 1200-second timeout
    local port=$1
    timeout 1200 bash -c "
        until curl -s localhost:${port}/v1/completions > /dev/null; do
            sleep 1
        done" && return 0 || return 1
}


launch_chunked_prefill() {
    model="meta-llama/Meta-Llama-3.1-8B-Instruct"
    # chunked prefill
    CUDA_VISIBLE_DEVICES=0 python3 \
        -m vllm.entrypoints.openai.api_server \
        --model $model \
        --port 8100 \
        --max-model-len 10000 \
        --enable-chunked-prefill \
        --gpu-memory-utilization 0.6 &
    CUDA_VISIBLE_DEVICES=1 python3 \
        -m vllm.entrypoints.openai.api_server \
        --model $model \
        --port 8200 \
        --max-model-len 10000 \
        --enable-chunked-prefill \
        --gpu-memory-utilization 0.6 &
    wait_for_server 8100
    wait_for_server 8200
    python3 round_robin_proxy.py &
    sleep 1
}


launch_disagg_prefill() {
    model="meta-llama/Meta-Llama-3.1-8B-Instruct"
    # disagg prefill
    CUDA_VISIBLE_DEVICES=0 python3 \
        -m vllm.entrypoints.openai.api_server \
        --model $model \
        --port 8100 \
        --max-model-len 10000 \
        --gpu-memory-utilization 0.6 \
        --kv-transfer-config \
        '{"kv_connector":"PyNcclConnector","kv_role":"kv_producer","kv_rank":0,"kv_parallel_size":2,"kv_buffer_size":5e9}' &

    CUDA_VISIBLE_DEVICES=1 python3 \
        -m vllm.entrypoints.openai.api_server \
        --model $model \
        --port 8200 \
        --max-model-len 10000 \
        --gpu-memory-utilization 0.6 \
        --kv-transfer-config \
        '{"kv_connector":"PyNcclConnector","kv_role":"kv_consumer","kv_rank":1,"kv_parallel_size":2,"kv_buffer_size":5e9}' &

    wait_for_server 8100
    wait_for_server 8200
    python3 disagg_prefill_proxy_server.py &
    sleep 1
}


benchmark() {
    results_folder="./results"
    model="meta-llama/Meta-Llama-3.1-8B-Instruct"
    dataset_name="sonnet"
    dataset_path="../sonnet_4x.txt"
    num_prompts=100
    qps=$1
    prefix_len=50
    input_len=1024
    output_len=$2
    tag=$3

    python3 ../benchmark_serving.py \
        --backend vllm \
        --model $model \
        --dataset-name $dataset_name \
        --dataset-path $dataset_path \
        --sonnet-input-len $input_len \
        --sonnet-output-len "$output_len" \
        --sonnet-prefix-len $prefix_len \
        --num-prompts $num_prompts \
        --port 8000 \
        --save-result \
        --result-dir $results_folder \
        --result-filename "$tag"-qps-"$qps".json \
        --request-rate "$qps"

    sleep 2

}


main() {

    (which wget && which curl) || (apt-get update && apt-get install -y wget curl)
    (which jq) || (apt-get -y install jq)
    (which socat) || (apt-get -y install socat)

    pip install quart httpx matplotlib aiohttp

    cd "$(dirname "$0")"

    cd ..
    # create sonnet_4x.txt so that we can sample 2048 tokens for input
    echo "" > sonnet_4x.txt
    for _ in {1..4}
    do
        cat sonnet.txt >> sonnet_4x.txt
    done
    cd disagg_benchmarks

    rm -rf results
    mkdir results

    default_output_len=6

    export VLLM_HOST_IP=$(hostname -I | awk '{print $1}')

    launch_chunked_prefill
    for qps in 2 4 6 8; do
        benchmark $qps $default_output_len chunked_prefill
    done
    kill_gpu_processes

    launch_disagg_prefill
    for qps in 2 4 6 8; do
        benchmark $qps $default_output_len disagg_prefill
    done
    kill_gpu_processes

    python3 visualize_benchmark_results.py

}


main "$@"
61 changes: 61 additions & 0 deletions benchmarks/disagg_benchmarks/disagg_prefill_proxy_server.py
@@ -0,0 +1,61 @@
import os

import aiohttp
from quart import Quart, make_response, request

AIOHTTP_TIMEOUT = aiohttp.ClientTimeout(total=6 * 60 * 60)

app = Quart(__name__)


async def forward_request(url, data):
    async with aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session:
        headers = {
            "Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}"
        }
        async with session.post(url=url, json=data,
                                headers=headers) as response:
            if response.status == 200:
                # if response.headers.get('Transfer-Encoding') == 'chunked':
                if True:
                    async for chunk_bytes in response.content.iter_chunked(
                            1024):
                        yield chunk_bytes
                else:
                    content = await response.read()
                    yield content


@app.route('/v1/completions', methods=['POST'])
async def handle_request():
    try:
        original_request_data = await request.get_json()

        prefill_request = original_request_data.copy()
        # set max_tokens to 1 so that the prefill instance only does prefill
        prefill_request['max_tokens'] = 1

        # finish prefill
        async for _ in forward_request('http://localhost:8100/v1/completions',
                                       prefill_request):
            continue

        # return decode
        generator = forward_request('http://localhost:8200/v1/completions',
                                    original_request_data)
        response = await make_response(generator)
        response.timeout = None

        return response

    except Exception as e:
        import sys
        import traceback
        exc_info = sys.exc_info()
        print("Error occurred in disagg prefill proxy server")
        print(e)
        print("".join(traceback.format_exception(*exc_info)))


if __name__ == '__main__':
    app.run(port=8000)
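
For reference, a request that exercises this proxy end to end. The model name and prompt are placeholders; as commit dda64b9 notes for curl, the request must be a POST:

```python
# Hypothetical client call: the proxy first forwards the request to the
# prefill instance with max_tokens=1, then streams the decode response back.
import requests

resp = requests.post(
    "http://localhost:8000/v1/completions",
    json={
        "model": "meta-llama/Meta-Llama-3.1-8B-Instruct",
        "prompt": "San Francisco is a",  # placeholder prompt
        "max_tokens": 32,
    },
)
print(resp.text)
```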