diff --git a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml index f7e4af4f2af4..027cb218df5e 100644 --- a/.buildkite/test-pipeline.yaml +++ b/.buildkite/test-pipeline.yaml @@ -214,6 +214,7 @@ steps: - pytest -v -s v1/worker - pytest -v -s v1/structured_output - pytest -v -s v1/spec_decode + - pytest -v -s v1/kv_connector/unit - pytest -v -s v1/test_serial_utils.py - pytest -v -s v1/test_stats.py - pytest -v -s v1/test_utils.py diff --git a/tests/v1/core/test_scheduler.py b/tests/v1/core/test_scheduler.py index 0ca2ced89148..f40d477a0036 100644 --- a/tests/v1/core/test_scheduler.py +++ b/tests/v1/core/test_scheduler.py @@ -870,7 +870,7 @@ def test_kv_connector_basic(): NUM_MATCHED_NEW_TOKENS = BLOCK_SIZE * 2 scheduler.connector.get_num_new_matched_tokens = Mock(name="method") scheduler.connector.get_num_new_matched_tokens.return_value = ( - NUM_MATCHED_NEW_TOKENS) + NUM_MATCHED_NEW_TOKENS, False) ###################################################### # FIRST SET OF REQUESTS - External Hit Only @@ -981,7 +981,7 @@ def test_kv_connector_unable_to_allocate(): NUM_MATCHED_NEW_TOKENS = BLOCK_SIZE * 2 scheduler.connector.get_num_new_matched_tokens = Mock(name="method") scheduler.connector.get_num_new_matched_tokens.return_value = ( - NUM_MATCHED_NEW_TOKENS) + NUM_MATCHED_NEW_TOKENS, False) # Create two requests. The second request will not be able to # allocate slots because it will not have enough blocks. @@ -1060,7 +1060,7 @@ def test_kv_connector_handles_preemption(): NUM_MATCHED_NEW_TOKENS = BLOCK_SIZE scheduler.connector.get_num_new_matched_tokens = Mock(name="method") scheduler.connector.get_num_new_matched_tokens.return_value = ( - NUM_MATCHED_NEW_TOKENS) + NUM_MATCHED_NEW_TOKENS, False) # Create two requests. # Both can be scheduled at first, but the second request diff --git a/tests/v1/kv_connector/nixl_integration/run_accuracy_test.sh b/tests/v1/kv_connector/nixl_integration/run_accuracy_test.sh new file mode 100755 index 000000000000..e90b72a7cf24 --- /dev/null +++ b/tests/v1/kv_connector/nixl_integration/run_accuracy_test.sh @@ -0,0 +1,171 @@ +#!/bin/bash +set -xe + +# Models to run +MODELS=( + "Qwen/Qwen3-0.6B" +) + +# Number of prefill and decode instances to create +NUM_PREFILL_INSTANCES=${NUM_PREFILL_INSTANCES:-1} # Default to 1 +NUM_DECODE_INSTANCES=${NUM_DECODE_INSTANCES:-2} # Default to 2 + +# Find the git repository root directory +GIT_ROOT=$(git rev-parse --show-toplevel) + +# Trap the SIGINT signal (triggered by Ctrl+C) +trap 'kill $(jobs -pr)' SIGINT SIGTERM EXIT + +# Waits for vLLM to start. +wait_for_server() { + local port=$1 + timeout 1200 bash -c " + until curl -s localhost:${port}/v1/completions > /dev/null; do + sleep 1 + done" && return 0 || return 1 +} + +# Function to clean up previous instances +cleanup_instances() { + echo "Cleaning up any running vLLM instances..." + pkill -f "vllm serve" || true + sleep 2 +} + +# Handle to get model-specific arguments for deepseek +get_model_args() { + local model_name=$1 + local extra_args="" + + if [[ "$model_name" == "deepseek-ai/deepseek-vl2-tiny" ]]; then + extra_args="--hf_overrides '{\"architectures\": [\"DeepseekVLV2ForCausalLM\"]}' --trust-remote-code" + fi + + echo "$extra_args" +} + + +# Function to run tests for a specific model +run_tests_for_model() { + local model_name=$1 + echo "================================" + echo "Testing model: $model_name" + echo "================================" + + # Get model-specific arguments + local model_args=$(get_model_args "$model_name") + + # Arrays to store all hosts and ports + PREFILL_HOSTS=() + PREFILL_PORTS=() + DECODE_HOSTS=() + DECODE_PORTS=() + + # Start prefill instances + for i in $(seq 0 $((NUM_PREFILL_INSTANCES-1))); do + # Calculate GPU ID - we'll distribute across available GPUs + GPU_ID=$((i % $(nvidia-smi --query-gpu=name --format=csv,noheader | wc -l))) + # Calculate port number (base port + instance number) + PORT=$((8100 + i)) + # Calculate side channel port + SIDE_CHANNEL_PORT=$((5559 + i)) + + echo "Starting prefill instance $i on GPU $GPU_ID, port $PORT" + + # Build the command with or without model-specific args + BASE_CMD="CUDA_VISIBLE_DEVICES=$GPU_ID VLLM_NIXL_SIDE_CHANNEL_PORT=$SIDE_CHANNEL_PORT vllm serve $model_name \ + --port $PORT \ + --enforce-eager \ + --disable-log-requests \ + --gpu-memory-utilization 0.2 \ + --kv-transfer-config '{\"kv_connector\":\"NixlConnector\",\"kv_role\":\"kv_both\"}'" + + if [ -n "$model_args" ]; then + FULL_CMD="$BASE_CMD $model_args" + else + FULL_CMD="$BASE_CMD" + fi + + eval "$FULL_CMD &" + + # Store host and port for proxy configuration + PREFILL_HOSTS+=("localhost") + PREFILL_PORTS+=($PORT) + done + + # Start decode instances + for i in $(seq 0 $((NUM_DECODE_INSTANCES-1))); do + # Calculate GPU ID - we'll distribute across available GPUs, starting from after prefill GPUs + GPU_ID=$(((i + NUM_PREFILL_INSTANCES) % $(nvidia-smi --query-gpu=name --format=csv,noheader | wc -l))) + # Calculate port number (base port + instance number) + PORT=$((8200 + i)) + # Calculate side channel port + SIDE_CHANNEL_PORT=$((5659 + i)) + + echo "Starting decode instance $i on GPU $GPU_ID, port $PORT" + + # Build the command with or without model-specific args + BASE_CMD="CUDA_VISIBLE_DEVICES=$GPU_ID VLLM_NIXL_SIDE_CHANNEL_PORT=$SIDE_CHANNEL_PORT vllm serve $model_name \ + --port $PORT \ + --enforce-eager \ + --disable-log-requests \ + --gpu-memory-utilization 0.2 \ + --kv-transfer-config '{\"kv_connector\":\"NixlConnector\",\"kv_role\":\"kv_both\"}'" + + if [ -n "$model_args" ]; then + FULL_CMD="$BASE_CMD $model_args" + else + FULL_CMD="$BASE_CMD" + fi + + eval "$FULL_CMD &" + + # Store host and port for proxy configuration + DECODE_HOSTS+=("localhost") + DECODE_PORTS+=($PORT) + done + + # Wait for all instances to start + for PORT in "${PREFILL_PORTS[@]}"; do + echo "Waiting for prefill instance on port $PORT to start..." + wait_for_server $PORT + done + + for PORT in "${DECODE_PORTS[@]}"; do + echo "Waiting for decode instance on port $PORT to start..." + wait_for_server $PORT + done + + # Build the command for the proxy server with all the hosts and ports + PROXY_CMD="python ${GIT_ROOT}/tests/v1/kv_connector/nixl_integration/toy_proxy_server.py --port 8192" + + # Add all prefill hosts and ports + PROXY_CMD+=" --prefiller-hosts ${PREFILL_HOSTS[@]}" + PROXY_CMD+=" --prefiller-ports ${PREFILL_PORTS[@]}" + + # Add all decode hosts and ports + PROXY_CMD+=" --decoder-hosts ${DECODE_HOSTS[@]}" + PROXY_CMD+=" --decoder-ports ${DECODE_PORTS[@]}" + + # Start the proxy server + echo "Starting proxy server with command: $PROXY_CMD" + $PROXY_CMD & + + # Wait for the proxy to start + sleep 5 + + # Run lm eval for this model + echo "Running tests for $model_name" + TEST_MODEL=$model_name python -m pytest -s -x ${GIT_ROOT}/tests/v1/kv_connector/nixl_integration/test_accuracy.py + + # Clean up before running next model + cleanup_instances + sleep 3 +} + +# Run tests for each model +for model in "${MODELS[@]}"; do + run_tests_for_model "$model" +done + +echo "All tests completed!" diff --git a/tests/v1/kv_connector/nixl_integration/run_edge_case_test.sh b/tests/v1/kv_connector/nixl_integration/run_edge_case_test.sh new file mode 100644 index 000000000000..98903a176e28 --- /dev/null +++ b/tests/v1/kv_connector/nixl_integration/run_edge_case_test.sh @@ -0,0 +1,123 @@ +#!/bin/bash +set -xe + +# Models to run +MODELS=( + "Qwen/Qwen3-0.6B" +) + +# Find the git repository root directory +GIT_ROOT=$(git rev-parse --show-toplevel) + +# Trap the SIGINT signal (triggered by Ctrl+C) +trap 'kill $(jobs -pr)' SIGINT SIGTERM EXIT + +# Waits for vLLM to start. +wait_for_server() { + local port=$1 + timeout 1200 bash -c " + until curl -s localhost:${port}/v1/completions > /dev/null; do + sleep 1 + done" && return 0 || return 1 +} + +# Function to clean up previous instances +cleanup_instances() { + echo "Cleaning up any running vLLM instances..." + pkill -f "vllm serve" || true + sleep 2 +} + +# Handle to get model-specific arguments for deepseek +get_model_args() { + local model_name=$1 + local extra_args="" + + if [[ "$model_name" == "deepseek-ai/deepseek-vl2-tiny" ]]; then + extra_args="--hf_overrides '{\"architectures\": [\"DeepseekVLV2ForCausalLM\"]}' --trust-remote-code" + fi + + echo "$extra_args" +} + + +# Function to run tests for a specific model +run_tests_for_model() { + local model_name=$1 + echo "================================" + echo "Testing model: $model_name" + echo "================================" + + # Get model-specific arguments + local model_args=$(get_model_args "$model_name") + + # Start prefill instance + PREFILL_PORT=8001 + + BASE_CMD="CUDA_VISIBLE_DEVICES=0 VLLM_NIXL_SIDE_CHANNEL_PORT=5559 vllm serve $model_name \ + --port $PREFILL_PORT \ + --enforce-eager \ + --disable-log-requests \ + --gpu-memory-utilization 0.2 \ + --kv-transfer-config '{\"kv_connector\":\"NixlConnector\",\"kv_role\":\"kv_both\"}'" + + if [ -n "$model_args" ]; then + FULL_CMD="$BASE_CMD $model_args" + else + FULL_CMD="$BASE_CMD" + fi + + eval "$FULL_CMD &" + + # Start decode instance + DECODE_PORT=8002 + + # Build the command with or without model-specific args + BASE_CMD="CUDA_VISIBLE_DEVICES=1 VLLM_NIXL_SIDE_CHANNEL_PORT=6000 vllm serve $model_name \ + --port $DECODE_PORT \ + --enforce-eager \ + --disable-log-requests \ + --gpu-memory-utilization 0.2 \ + --kv-transfer-config '{\"kv_connector\":\"NixlConnector\",\"kv_role\":\"kv_both\"}'" + + if [ -n "$model_args" ]; then + FULL_CMD="$BASE_CMD $model_args" + else + FULL_CMD="$BASE_CMD" + fi + + eval "$FULL_CMD &" + + # Wait for all instances to start + echo "Waiting for prefill instance on port $PORT to start..." + wait_for_server $PREFILL_PORT + echo "Waiting for decode instance on port $PORT to start..." + wait_for_server $DECODE_PORT + + # Build the command for the proxy server with all the hosts and ports + PROXY_PORT=8192 + PROXY_CMD="python ${GIT_ROOT}/tests/v1/kv_connector/nixl_integration/toy_proxy_server.py --port $PROXY_PORT" + PROXY_CMD+=" --prefiller-ports ${PREFILL_PORT}" + PROXY_CMD+=" --decoder-ports ${DECODE_PORT}" + # Start the proxy server + echo "Starting proxy server with command: $PROXY_CMD" + $PROXY_CMD & + + # Wait for the proxy to start + sleep 5 + + # Run lm eval for this model + echo "Running tests for $model_name" + PREFILL_PORT=$PREFILL_PORT DECODE_PORT=$DECODE_PORT PROXY_PORT=$PROXY_PORT python -m pytest -s -v ${GIT_ROOT}/tests/v1/kv_connector/nixl_integration/test_edge_cases.py + + # Clean up before running next model + cleanup_instances + sleep 3 +} + +# Run tests for each model +for model in "${MODELS[@]}"; do + run_tests_for_model "$model" +done + +echo "All tests completed!" diff --git a/tests/v1/kv_connector/nixl_integration/test_accuracy.py b/tests/v1/kv_connector/nixl_integration/test_accuracy.py new file mode 100644 index 000000000000..be2d84f3bb17 --- /dev/null +++ b/tests/v1/kv_connector/nixl_integration/test_accuracy.py @@ -0,0 +1,60 @@ +# SPDX-License-Identifier: Apache-2.0 +import os + +import lm_eval +import openai + +BASE_URL = "http://localhost:8192/v1" +NUM_CONCURRENT = 100 +TASK = "gsm8k" +FILTER = "exact_match,strict-match" +RTOL = 0.03 + +# Model-specific expected values +EXPECTED_VALUES = { + "Qwen/Qwen3-0.6B": 0.41, +} + +SIMPLE_PROMPT = "The best part about working on vLLM is that I got to meet so many people across various different organizations like UCB, Google, and Meta which means", # noqa: E501 + +# Get model name from environment variable +MODEL_NAME = os.environ.get("TEST_MODEL", "Qwen/Qwen3-0.6B") + + +def run_simple_prompt(): + client = openai.OpenAI(api_key="EMPTY", base_url=BASE_URL) + completion = client.completions.create(model=MODEL_NAME, + prompt=SIMPLE_PROMPT) + + print("-" * 50) + print(f"Completion results for {MODEL_NAME}:") + print(completion) + print("-" * 50) + + +def test_accuracy(): + """Run the end to end accuracy test.""" + run_simple_prompt() + + model_args = (f"model={MODEL_NAME}," + f"base_url={BASE_URL}/completions," + f"num_concurrent={NUM_CONCURRENT},tokenized_requests=False") + + results = lm_eval.simple_evaluate( + model="local-completions", + model_args=model_args, + tasks=TASK, + ) + + measured_value = results["results"][TASK][FILTER] + expected_value = EXPECTED_VALUES.get(MODEL_NAME) + + if expected_value is None: + print(f"Warning: No expected value found for {MODEL_NAME}. " + "Skipping accuracy check.") + print(f"Measured value: {measured_value}") + return + + assert (measured_value - RTOL < expected_value + and measured_value + RTOL > expected_value + ), f"Expected: {expected_value} | Measured: {measured_value}" diff --git a/tests/v1/kv_connector/nixl_integration/test_edge_cases.py b/tests/v1/kv_connector/nixl_integration/test_edge_cases.py new file mode 100644 index 000000000000..5363fbde0096 --- /dev/null +++ b/tests/v1/kv_connector/nixl_integration/test_edge_cases.py @@ -0,0 +1,77 @@ +# SPDX-License-Identifier: Apache-2.0 +import os + +import openai + +PREFILL_PORT = os.getenv("PREFILL_PORT", None) +DECODE_PORT = os.getenv("DECODE_PORT", None) +PROXY_PORT = os.getenv("PROXY_PORT", None) + +if PREFILL_PORT is None or DECODE_PORT is None or PROXY_PORT is None: + raise ValueError( + "Please set the PREFILL_PORT, DECODE_PORT, and PROXY_PORT.") + +LONG_PROMPT = "Red Hat is the best company in the world to work for because it works on open source software, which means that all the contributions are delivered to the community. As a result, when working on projects like vLLM we are able to meet many amazing people from various organizations like AMD, Google, NVIDIA, " # noqa: E501 +PROMPT = "Red Hat is the best company in the world to work for because it works on open source software, which means that all the contributions are delivered to the community. As a result," # noqa: E501 +SHORT_PROMPT = "Red Hat is " + + +def test_edge_cases(): + # Set the OpenAI API key and base URL + decode_client = openai.OpenAI( + api_key="MY_KEY", + base_url=f"http://localhost:{DECODE_PORT}/v1", + ) + prefill_client = openai.OpenAI( + api_key="MY_KEY", + base_url=f"http://localhost:{PREFILL_PORT}/v1", + ) + proxy_client = openai.OpenAI( + api_key="MY_KEY", + base_url=f"http://localhost:{PROXY_PORT}/v1", + ) + + # Get the list of models + models = decode_client.models.list() + MODEL = models.data[0].id + + # (1) Check that we can handle a very short prompt, + # less than the length of the block size. + completion = proxy_client.completions.create(model=MODEL, + prompt=SHORT_PROMPT, + temperature=0) + proxy_response = completion.choices[0].text + completion = prefill_client.completions.create(model=MODEL, + prompt=SHORT_PROMPT, + temperature=0) + prefill_response = completion.choices[0].text + print(f"SMALL PROMPT: {proxy_response=}") + assert proxy_response == prefill_response + + # (2) Check that we can handle a full prefix cache + # hit on the D worker but not on the P worker. + # (2a): prime the D worker. + completion = decode_client.completions.create(model=MODEL, + prompt=PROMPT, + temperature=0) + decode_response = completion.choices[0].text + # (2b): send via the P/D setup + completion = proxy_client.completions.create(model=MODEL, + prompt=PROMPT, + temperature=0) + proxy_response = completion.choices[0].text + print(f"FULL CACHE HIT: {proxy_response=}") + assert proxy_response == decode_response + + # (3) Check that we can handle a partial prefix cache + # hit on the D worker. + completion = proxy_client.completions.create(model=MODEL, + prompt=LONG_PROMPT, + temperature=0) + proxy_response = completion.choices[0].text + completion = prefill_client.completions.create(model=MODEL, + prompt=LONG_PROMPT, + temperature=0) + prefill_response = completion.choices[0].text + print(f"PARTIAL CACHE HIT: {proxy_response=}") + assert proxy_response == prefill_response diff --git a/tests/v1/kv_connector/nixl_integration/toy_proxy_server.py b/tests/v1/kv_connector/nixl_integration/toy_proxy_server.py new file mode 100644 index 000000000000..13071f581375 --- /dev/null +++ b/tests/v1/kv_connector/nixl_integration/toy_proxy_server.py @@ -0,0 +1,260 @@ +# SPDX-License-Identifier: Apache-2.0 + +import argparse +import itertools +import os +import uuid +from contextlib import asynccontextmanager + +import httpx +from fastapi import FastAPI, Request +from fastapi.responses import StreamingResponse + +from vllm.logger import init_logger + +logger = init_logger(__name__) + + +@asynccontextmanager +async def lifespan(app: FastAPI): + """ + Lifespan context manager to handle startup and shutdown events. + """ + # Startup: Initialize client pools for prefiller and decoder services + app.state.prefill_clients = [] + app.state.decode_clients = [] + + # Create prefill clients + for i, (host, port) in enumerate(global_args.prefiller_instances): + prefiller_base_url = f'http://{host}:{port}/v1' + app.state.prefill_clients.append({ + 'client': + httpx.AsyncClient(timeout=None, base_url=prefiller_base_url), + 'host': + host, + 'port': + port, + 'id': + i + }) + + # Create decode clients + for i, (host, port) in enumerate(global_args.decoder_instances): + decoder_base_url = f'http://{host}:{port}/v1' + app.state.decode_clients.append({ + 'client': + httpx.AsyncClient(timeout=None, base_url=decoder_base_url), + 'host': + host, + 'port': + port, + 'id': + i + }) + + # Initialize round-robin iterators + app.state.prefill_iterator = itertools.cycle( + range(len(app.state.prefill_clients))) + app.state.decode_iterator = itertools.cycle( + range(len(app.state.decode_clients))) + + print(f"Initialized {len(app.state.prefill_clients)} prefill clients " + f"and {len(app.state.decode_clients)} decode clients.") + + yield + + # Shutdown: Close all clients + for client_info in app.state.prefill_clients: + await client_info['client'].aclose() + + for client_info in app.state.decode_clients: + await client_info['client'].aclose() + + +# Update FastAPI app initialization to use lifespan +app = FastAPI(lifespan=lifespan) + + +def parse_args(): + parser = argparse.ArgumentParser() + + parser.add_argument("--port", type=int, default=8000) + parser.add_argument("--host", type=str, default="localhost") + + # For prefiller instances + parser.add_argument("--prefiller-hosts", + "--prefiller-host", + type=str, + nargs="+", + default=["localhost"]) + parser.add_argument("--prefiller-ports", + "--prefiller-port", + type=int, + nargs="+", + default=[8100]) + + # For decoder instances + parser.add_argument("--decoder-hosts", + "--decoder-host", + type=str, + nargs="+", + default=["localhost"]) + parser.add_argument("--decoder-ports", + "--decoder-port", + type=int, + nargs="+", + default=[8200]) + + args = parser.parse_args() + + # Validate and pair hosts with ports + if len(args.prefiller_hosts) != len(args.prefiller_ports): + raise ValueError( + "Number of prefiller hosts must match number of prefiller ports") + + if len(args.decoder_hosts) != len(args.decoder_ports): + raise ValueError( + "Number of decoder hosts must match number of decoder ports") + + # Create tuples of (host, port) for each service type + args.prefiller_instances = list( + zip(args.prefiller_hosts, args.prefiller_ports)) + args.decoder_instances = list(zip(args.decoder_hosts, args.decoder_ports)) + + return args + + +def get_next_client(app, service_type: str): + """ + Get the next client in round-robin fashion. + + Args: + app: The FastAPI app instance + service_type: Either 'prefill' or 'decode' + + Returns: + The next client to use + """ + if service_type == 'prefill': + client_idx = next(app.state.prefill_iterator) + return app.state.prefill_clients[client_idx] + elif service_type == 'decode': + client_idx = next(app.state.decode_iterator) + return app.state.decode_clients[client_idx] + else: + raise ValueError(f"Unknown service type: {service_type}") + + +async def send_request_to_service(client_info: dict, endpoint: str, + req_data: dict, request_id: str): + """ + Send a request to a service using a client from the pool. + """ + req_data = req_data.copy() + req_data['kv_transfer_params'] = { + "do_remote_decode": True, + "do_remote_prefill": False, + "remote_engine_id": None, + "remote_block_ids": None, + "remote_host": None, + "remote_port": None + } + req_data["stream"] = False + req_data["max_tokens"] = 1 + if "stream_options" in req_data: + del req_data["stream_options"] + headers = { + "Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}", + "X-Request-Id": request_id + } + + response = await client_info['client'].post(endpoint, + json=req_data, + headers=headers) + response.raise_for_status() + + return response + + +async def stream_service_response(client_info: dict, endpoint: str, + req_data: dict, request_id: str): + """ + Asynchronously stream response from a service using a client from the pool. + """ + headers = { + "Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}", + "X-Request-Id": request_id + } + + async with client_info['client'].stream("POST", + endpoint, + json=req_data, + headers=headers) as response: + response.raise_for_status() + async for chunk in response.aiter_bytes(): + yield chunk + + +@app.post("/v1/completions") +async def handle_completions(request: Request): + try: + req_data = await request.json() + request_id = str(uuid.uuid4()) + + # Get the next prefill client in round-robin fashion + prefill_client_info = get_next_client(request.app, 'prefill') + + # Send request to prefill service + response = await send_request_to_service(prefill_client_info, + "/completions", req_data, + request_id) + + # Extract the needed fields + response_json = response.json() + kv_transfer_params = response_json.get('kv_transfer_params', {}) + if kv_transfer_params: + req_data["kv_transfer_params"] = kv_transfer_params + + # Get the next decode client in round-robin fashion + decode_client_info = get_next_client(request.app, 'decode') + + logger.debug("Using %s %s", prefill_client_info, decode_client_info) + + # Stream response from decode service + async def generate_stream(): + async for chunk in stream_service_response(decode_client_info, + "/completions", + req_data, + request_id=request_id): + yield chunk + + return StreamingResponse(generate_stream(), + media_type="application/json") + + except Exception as e: + import sys + import traceback + exc_info = sys.exc_info() + print("Error occurred in disagg prefill proxy server" + " - completions endpoint") + print(e) + print("".join(traceback.format_exception(*exc_info))) + raise + + +@app.get("/healthcheck") +async def healthcheck(): + """Simple endpoint to check if the server is running.""" + return { + "status": "ok", + "prefill_instances": len(app.state.prefill_clients), + "decode_instances": len(app.state.decode_clients) + } + + +if __name__ == '__main__': + global global_args + global_args = parse_args() + + import uvicorn + uvicorn.run(app, host=global_args.host, port=global_args.port) diff --git a/tests/v1/kv_connector/unit/__init__.py b/tests/v1/kv_connector/unit/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/tests/v1/kv_connector/unit/test_nixl_connector.py b/tests/v1/kv_connector/unit/test_nixl_connector.py new file mode 100644 index 000000000000..9b2a720c11c4 --- /dev/null +++ b/tests/v1/kv_connector/unit/test_nixl_connector.py @@ -0,0 +1,73 @@ +# SPDX-License-Identifier: Apache-2.0 + +from vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector import ( + NixlConnectorMetadata) + +from .utils import create_request, create_scheduler, create_vllm_config + + +def test_basic_inferface(): + """Unit test for basic NixlConnector interface functionality.""" + + vllm_config = create_vllm_config() + scheduler = create_scheduler(vllm_config) + + # 2 Full Blocks and 1 Half Block. + BLOCK_SIZE = vllm_config.cache_config.block_size + NUM_EXTERNAL_FULL_BLOCKS = 2 + NUM_TOKENS = int(BLOCK_SIZE * (NUM_EXTERNAL_FULL_BLOCKS + 0.5)) + + request = create_request(request_id=1, + num_tokens=NUM_TOKENS, + do_remote_prefill=True) + request_id = request.request_id + + scheduler.add_request(request) + + # Remote Prefill, triggers NixlConnectorMetdata. + scheduler_output = scheduler.schedule() + kv_connector_metadata = scheduler_output.kv_connector_metadata + assert kv_connector_metadata is not None + assert isinstance(kv_connector_metadata, NixlConnectorMetadata) + + assert len(kv_connector_metadata.requests) == 1 + assert request_id in kv_connector_metadata.requests + req_meta = kv_connector_metadata.requests[request_id] + + for block_id, block in zip( + req_meta.local_block_ids, scheduler.kv_cache_manager. + single_type_manager.req_to_blocks[request_id]): + assert block_id == block.block_id + + +def test_prompt_less_than_block_size(): + """ + Test that we can handle case where prompt is < block. + + In this case, the P worker will send empty remote_block_ids. + The D worker should not schedule an async read in this case, + since there is nothing to pull. + """ + vllm_config = create_vllm_config() + scheduler = create_scheduler(vllm_config) + + # Half of a block. + BLOCK_SIZE = vllm_config.cache_config.block_size + NUM_TOKENS = int(BLOCK_SIZE * 0.5) + + # Request will have 0 remote blocks. + request = create_request(request_id=1, + num_tokens=NUM_TOKENS, + do_remote_prefill=True, + num_remote_blocks=0) + scheduler.add_request(request) + scheduler_output = scheduler.schedule() + + # This request should not have to read async. + kv_connector_metadata = scheduler_output.kv_connector_metadata + assert kv_connector_metadata is not None + assert isinstance(kv_connector_metadata, NixlConnectorMetadata) + assert len(kv_connector_metadata.requests) == 0 + + # This request should be scheduled regularly. + assert len(scheduler_output.scheduled_new_reqs) == 1 diff --git a/tests/v1/kv_connector/unit/test_remote_decode_lifecycle.py b/tests/v1/kv_connector/unit/test_remote_decode_lifecycle.py new file mode 100644 index 000000000000..77098140343a --- /dev/null +++ b/tests/v1/kv_connector/unit/test_remote_decode_lifecycle.py @@ -0,0 +1,181 @@ +# SPDX-License-Identifier: Apache-2.0 +import copy + +from vllm.v1.outputs import EMPTY_MODEL_RUNNER_OUTPUT +from vllm.v1.request import FinishReason, RequestStatus + +from .utils import (assert_scheduler_empty, create_model_runner_output, + create_request, create_scheduler, create_vllm_config) + + +def test_basic_lifecycle(): + """Test lifecycle of a Remote Decode request.""" + + vllm_config = create_vllm_config() + scheduler = create_scheduler(vllm_config) + + # 2 Full Blocks and 1 Half Block. + BLOCK_SIZE = vllm_config.cache_config.block_size + NUM_EXTERNAL_FULL_BLOCKS = 2 + NUM_TOKENS = int(BLOCK_SIZE * (NUM_EXTERNAL_FULL_BLOCKS + 0.5)) + + request = create_request(request_id=1, + max_tokens=1, + num_tokens=NUM_TOKENS, + do_remote_decode=True) + + scheduler.add_request(request) + request_id = request.request_id + + # STEP (1): Prefill. + # (1a): schedule() + scheduler_output = scheduler.schedule() + assert len(scheduler.running) == 1 + assert len(scheduler_output.scheduled_new_reqs) == 1 + + # (1b): execute_model() + model_runner_output = create_model_runner_output(reqs=[request]) + + # (1c): update_from_output() + engine_core_outputs = scheduler.update_from_output(scheduler_output, + model_runner_output) + + # Ensure the request is finished after 1 tokens. + assert request.is_finished() + assert request.status == RequestStatus.FINISHED_LENGTH_CAPPED + output = engine_core_outputs.outputs[0] + assert output.finish_reason == FinishReason.LENGTH + assert output.kv_transfer_params is not None + + # Request freed in Scheduler and in Persistent Batch ... + assert request_id in scheduler.finished_req_ids + assert len(scheduler.running) == 0 + assert len(scheduler.waiting) == 0 + + # ... but blocks should not be freed. + blocks = scheduler.kv_cache_manager.single_type_manager.req_to_blocks[ + request_id] + for block in blocks: + assert block.ref_cnt == 1 + + # STEP (2): Send Finished to PB. + # (2a): schedule() - pass finished request to PB. + scheduler_output = scheduler.schedule() + assert len(scheduler.running) == 0 + assert len(scheduler_output.finished_req_ids) == 1 + assert request_id in scheduler_output.finished_req_ids + assert len(scheduler_output.scheduled_new_reqs) == 0 + assert len(scheduler_output.scheduled_cached_reqs) == 0 + assert len(scheduler.finished_req_ids) == 0 + + # (2b): execute_model() + model_runner_output = EMPTY_MODEL_RUNNER_OUTPUT + + # (2c): update_from_output() + scheduler.update_from_output(scheduler_output, model_runner_output) + + # STEP (3): Finished sending. + # (3a): schedule() - pass finished request to PB. + scheduler_output = scheduler.schedule() + assert len(scheduler.running) == 0 + assert len(scheduler_output.finished_req_ids) == 0 + assert len(scheduler_output.scheduled_new_reqs) == 0 + assert len(scheduler_output.scheduled_cached_reqs) == 0 + assert len(scheduler.finished_req_ids) == 0 + + # (3b): execute_model() + model_runner_output = copy.deepcopy(EMPTY_MODEL_RUNNER_OUTPUT) + model_runner_output.finished_sending = [request_id] + + # (3c): update_from_output() + scheduler.update_from_output(scheduler_output, model_runner_output) + + # Confirm we do not have any memory leaks after req lifecycle. + assert_scheduler_empty(scheduler) + + +def test_short_prompt_lifecycle(): + """Test lifecycle of a Remote Decode request with short prompt.""" + + vllm_config = create_vllm_config() + scheduler = create_scheduler(vllm_config) + + # Not enough tokens for full block. + NUM_TOKENS = vllm_config.cache_config.block_size // 2 + request = create_request(request_id=1, + max_tokens=1, + num_tokens=NUM_TOKENS, + do_remote_decode=True) + + scheduler.add_request(request) + + # STEP (1): Prefill. + # (1a): schedule() + scheduler_output = scheduler.schedule() + assert len(scheduler.running) == 1 + assert len(scheduler_output.scheduled_new_reqs) == 1 + + # (1b): execute_model() + model_runner_output = create_model_runner_output(reqs=[request]) + + # (1c): update_from_output() + # Since tokens < block_size, there will be no kv xfer. + # So this should be cleaned up immediately. + _ = scheduler.update_from_output(scheduler_output, model_runner_output) + + # Confirm we do not have any memory leaks after req lifecycle. + # We need one more call to schedule() to clear data for persistent batch. + _ = scheduler.schedule() + assert_scheduler_empty(scheduler) + + +def test_prefix_cache_lifecycle(): + """Test that remote decode params still works with a prefix cache hit.""" + + vllm_config = create_vllm_config() + scheduler = create_scheduler(vllm_config) + + # Prime the KVCache. + BLOCK_SIZE = vllm_config.cache_config.block_size + NUM_EXTERNAL_FULL_BLOCKS = 3 + NUM_TOKENS = int(BLOCK_SIZE * (NUM_EXTERNAL_FULL_BLOCKS + 0.5)) + + request_normal = create_request(request_id=1, num_tokens=NUM_TOKENS) + + scheduler.add_request(request_normal) + scheduler_output = scheduler.schedule() + model_runner_output = create_model_runner_output(reqs=[request_normal], + use_eos=True) + scheduler.update_from_output(scheduler_output, model_runner_output) + scheduler.schedule() + scheduler.update_from_output(scheduler_output, EMPTY_MODEL_RUNNER_OUTPUT) + + ##################### + # Actual Test: confirm we send all blocks. + + # Step (1): Send the KV Transfer. + NUM_EXTERNAL_FULL_BLOCKS -= 1 + NUM_TOKENS = int(BLOCK_SIZE * (NUM_EXTERNAL_FULL_BLOCKS + 0.5)) + + request_remote = create_request(request_id=1, + num_tokens=NUM_TOKENS, + do_remote_decode=True) + + scheduler.add_request(request_remote) + scheduler_output = scheduler.schedule() + model_runner_output = create_model_runner_output(reqs=[request_remote]) + eco = scheduler.update_from_output(scheduler_output, model_runner_output) + kv_transfer_params = eco.outputs[0].kv_transfer_params + + # Ensure we send all block ids, even if there is a cache hit. + assert (len( + kv_transfer_params["remote_block_ids"]) == NUM_EXTERNAL_FULL_BLOCKS) + + # STEP (2): Ensure it is freed. + scheduler_output = scheduler.schedule() + scheduler.schedule() + model_runner_output = copy.deepcopy(EMPTY_MODEL_RUNNER_OUTPUT) + model_runner_output.finished_sending = [request_remote.request_id] + scheduler.update_from_output(scheduler_output, model_runner_output) + _ = scheduler.schedule() + assert_scheduler_empty(scheduler) diff --git a/tests/v1/kv_connector/unit/test_remote_prefill_lifecycle.py b/tests/v1/kv_connector/unit/test_remote_prefill_lifecycle.py new file mode 100644 index 000000000000..fc4928f9ebd1 --- /dev/null +++ b/tests/v1/kv_connector/unit/test_remote_prefill_lifecycle.py @@ -0,0 +1,342 @@ +# SPDX-License-Identifier: Apache-2.0 +import copy + +from vllm.v1.outputs import EMPTY_MODEL_RUNNER_OUTPUT +from vllm.v1.request import FinishReason, RequestStatus + +from .utils import (assert_scheduler_empty, create_model_runner_output, + create_request, create_scheduler, create_vllm_config) + + +def test_basic_lifecycle(): + """Test lifecycle of a remote prefill.""" + + vllm_config = create_vllm_config() + scheduler = create_scheduler(vllm_config) + + # 2 Full Blocks and 1 Half Block. + BLOCK_SIZE = vllm_config.cache_config.block_size + NUM_EXTERNAL_FULL_BLOCKS = 2 + NUM_TOKENS = int(BLOCK_SIZE * (NUM_EXTERNAL_FULL_BLOCKS + 0.5)) + START_FREE_BLOCK_QUEUE_SIZE = ( + scheduler.kv_cache_manager.block_pool.free_block_queue.num_free_blocks) + + request = create_request(request_id=1, + num_tokens=NUM_TOKENS, + do_remote_prefill=True) + + scheduler.add_request(request) + request_id = request.request_id + + # STEP (1): + # (1a): schedule() + scheduler_output = scheduler.schedule() + + # Nothing running and empty scheduler output. + assert len(scheduler.running) == 0 + assert len(scheduler_output.scheduled_new_reqs) == 0 + assert len(scheduler_output.scheduled_cached_reqs) == 0 + assert len(scheduler_output.num_scheduled_tokens) == 0 + assert scheduler_output.total_num_scheduled_tokens == 0 + + # Req waiting for KVs with no computed/scheduled toks ... + assert len(scheduler.waiting) == 1 + assert request in scheduler.waiting + assert (request.status == RequestStatus.WAITING_FOR_REMOTE_KVS) + assert (request.num_computed_tokens == 0) + + # ... but should have (uncached) blocks allocated to it. + block_pool = scheduler.kv_cache_manager.block_pool + assert (block_pool.free_block_queue.num_free_blocks + < START_FREE_BLOCK_QUEUE_SIZE) + assert len(block_pool.cached_block_hash_to_block) == 0 + blocks = scheduler.kv_cache_manager.single_type_manager.req_to_blocks[ + request_id] + for block in blocks: + assert block._block_hash is None + + # (1b): forward() + model_runner_output = EMPTY_MODEL_RUNNER_OUTPUT + + # (1c): update_from_output() + engine_core_outputs = scheduler.update_from_output(scheduler_output, + model_runner_output) + assert len(engine_core_outputs.outputs) == 0 + + # STEP (2): + # (2a): schedule(): nothing happens! + scheduler_output = scheduler.schedule() + assert len(scheduler.waiting) == 1 + assert len(scheduler.running) == 0 + + # (2b): forward(): request finishes recv. + model_runner_output = copy.deepcopy(EMPTY_MODEL_RUNNER_OUTPUT) + model_runner_output.finished_recving = [request_id] + + # (2c): update_from_output(): + engine_core_outputs = scheduler.update_from_output(scheduler_output, + model_runner_output) + assert len(scheduler.waiting) == 1 + assert (request_id in scheduler.finished_recving_kv_req_ids) + + # STEP (3): + # (3a): schedule(): this should actually schedule. + scheduler_output = scheduler.schedule() + assert len(scheduler.running) == 1 + + # Confirm the block are actually allocated. + num_hashed_blocks = 0 + blocks = scheduler.kv_cache_manager.single_type_manager.req_to_blocks[ + request_id] + for block in blocks: + assert block.ref_cnt == 1 + num_hashed_blocks += (1 if block._block_hash is not None else 0) + assert num_hashed_blocks == NUM_EXTERNAL_FULL_BLOCKS + + # Confirm the rest of the prompt is scheduled in this step. + scheduled_req = scheduler_output.scheduled_new_reqs[0] + num_scheduled_tokens = scheduler_output.num_scheduled_tokens[request_id] + num_computed_tokens = scheduled_req.num_computed_tokens + total_prompt_tokens = len(scheduled_req.prompt_token_ids) + assert (num_scheduled_tokens == total_prompt_tokens - num_computed_tokens) + + # (3b): execute_model() + model_runner_output = create_model_runner_output([request]) + # (3c): update_from_output() + scheduler.update_from_output(scheduler_output, model_runner_output) + + # Step (4): Hit EOS. + scheduler_output = scheduler.schedule() + model_runner_output = create_model_runner_output([request], use_eos=True) + engine_core_outputs = scheduler.update_from_output(scheduler_output, + model_runner_output) + scheduler.schedule() + + outputs = engine_core_outputs.outputs + assert len(outputs) == 1 + output = outputs[0] + assert output.finish_reason == FinishReason.STOP + assert_scheduler_empty(scheduler) + + +def test_interleaved_lifecycle(): + """Test Remote Prefills Work Well With Other Requests.""" + + vllm_config = create_vllm_config() + scheduler = create_scheduler(vllm_config) + + # 2 Full Blocks and 1 Half Block. + BLOCK_SIZE = vllm_config.cache_config.block_size + NUM_EXTERNAL_FULL_BLOCKS = 2 + NUM_TOKENS = int(BLOCK_SIZE * (NUM_EXTERNAL_FULL_BLOCKS + 0.5)) + + request_remote = create_request(request_id=1, + num_tokens=NUM_TOKENS, + do_remote_prefill=True) + request_local_a = create_request( + request_id=2, + num_tokens=NUM_TOKENS, + ) + request_local_b = create_request( + request_id=3, + num_tokens=NUM_TOKENS, + ) + + # STEP 1: Regular request is running. + scheduler.add_request(request_local_a) + scheduler_output = scheduler.schedule() + assert len(scheduler.running) == 1 + + model_runner_output = create_model_runner_output([request_local_a]) + scheduler.update_from_output(scheduler_output, model_runner_output) + + # STEP 2: Add a local and remote request. + scheduler.add_request(request_local_b) + scheduler.add_request(request_remote) + scheduler_output = scheduler.schedule() + assert len(scheduler.running) == 2 + assert len(scheduler.waiting) == 1 + assert len(scheduler_output.scheduled_new_reqs) == 1 + assert len(scheduler_output.scheduled_cached_reqs) == 1 + + model_runner_output = create_model_runner_output( + [request_local_a, request_local_b]) + scheduler.update_from_output(scheduler_output, model_runner_output) + + # STEP 3: continue running, KVs not arrived yet. + scheduler_output = scheduler.schedule() + assert len(scheduler.running) == 2 + assert len(scheduler.waiting) == 1 + assert len(scheduler_output.scheduled_new_reqs) == 0 + assert len(scheduler_output.scheduled_cached_reqs) == 2 + + model_runner_output = create_model_runner_output( + reqs=[request_local_a, request_local_b]) + scheduler.update_from_output(scheduler_output, model_runner_output) + assert len(scheduler.running) == 2 + assert len(scheduler.waiting) == 1 + assert len(scheduler_output.scheduled_new_reqs) == 0 + assert len(scheduler_output.scheduled_cached_reqs) == 2 + + # STEP 4: KVs arrive. + scheduler_output = scheduler.schedule() + assert len(scheduler.running) == 2 + assert len(scheduler.waiting) == 1 + assert len(scheduler_output.scheduled_new_reqs) == 0 + assert len(scheduler_output.scheduled_cached_reqs) == 2 + + model_runner_output = create_model_runner_output( + [request_local_a, request_local_b], + finished_recving=[request_remote.request_id]) + scheduler.update_from_output(scheduler_output, model_runner_output) + + # STEP 5: RECVed KVs are sent to ModelRunner. + scheduler_output = scheduler.schedule() + assert len(scheduler.running) == 3 + assert len(scheduler.waiting) == 0 + assert len(scheduler_output.scheduled_new_reqs) == 1 + assert len(scheduler_output.scheduled_cached_reqs) == 2 + + model_runner_output = create_model_runner_output( + [request_local_a, request_local_b, request_remote]) + scheduler.update_from_output(scheduler_output, model_runner_output) + + # STEP 6: Hit EOS and free. + scheduler_output = scheduler.schedule() + model_runner_output = create_model_runner_output( + [request_local_a, request_local_b, request_remote], + use_eos=True, + ) + scheduler.update_from_output(scheduler_output, model_runner_output) + scheduler.schedule() + assert_scheduler_empty(scheduler) + + +def test_no_spurious_prefix_caching(): + """ + With P/D, blocks can be allocated but uncomputed for + multiple engine steps. This test confirms that we do + not accidentally have cache hits against uncomputed + blocks. + """ + + vllm_config = create_vllm_config() + scheduler = create_scheduler(vllm_config) + + vllm_config = create_vllm_config() + scheduler = create_scheduler(vllm_config) + + # 2 and a half full external blocks. + BLOCK_SIZE = vllm_config.cache_config.block_size + NUM_EXTERNAL_FULL_BLOCKS = 2 + NUM_TOKENS = int(BLOCK_SIZE * (NUM_EXTERNAL_FULL_BLOCKS + 0.5)) + + # Both of these requests have prompts like [1,1,1,1,1, ...] + request_remote = create_request( + request_id=1, + num_tokens=NUM_TOKENS, + do_remote_prefill=True, + use_all_1s_for_prompt_tokens=True, + ) + + request_local = create_request( + request_id=2, + num_tokens=NUM_TOKENS, + do_remote_prefill=False, + use_all_1s_for_prompt_tokens=True, + ) + + # Schedule the remote prefill request. This should not + # cause any blocks to be cached. + scheduler.add_request(request_remote) + scheduler_output = scheduler.schedule() + scheduler.update_from_output(scheduler_output, EMPTY_MODEL_RUNNER_OUTPUT) + assert len(scheduler.waiting) == 1 + + # Schedule the local prefill request. This should + # cause blocks to be cached, but separately from + scheduler.add_request(request_local) + scheduler_output = scheduler.schedule() + assert len(scheduler.running) == 1 + assert len(scheduler.waiting) == 1 + + local_blocks = scheduler.kv_cache_manager.single_type_manager.req_to_blocks[ + request_local.request_id] + remote_blocks = scheduler.kv_cache_manager.single_type_manager.req_to_blocks[ # noqa: E501 + request_remote.request_id] + + # Local should have cached blocks (but not all due to preallocate). + num_hashed_blocks = 0 + for block in local_blocks: + assert block.ref_cnt == 1 + num_hashed_blocks += (1 if block._block_hash is not None else 0) + assert num_hashed_blocks > 0 + + # Remote blocks should not be cached. + for block in remote_blocks: + assert block.ref_cnt == 1 + assert block._block_hash is None + + +def test_full_block_prompt(): + """Test that we handle a prompt that is the full block size.""" + + vllm_config = create_vllm_config() + scheduler = create_scheduler(vllm_config) + + # 2 Full Blocks and 1 Half Block. + BLOCK_SIZE = vllm_config.cache_config.block_size + NUM_EXTERNAL_FULL_BLOCKS = 2 + NUM_TOKENS = int(BLOCK_SIZE * NUM_EXTERNAL_FULL_BLOCKS) + + request = create_request(request_id=1, + num_tokens=NUM_TOKENS, + do_remote_prefill=True) + + scheduler.add_request(request) + request_id = request.request_id + + # STEP (1): Initialize a recv. + scheduler_output = scheduler.schedule() + # All blocks should be allocated. + num_blocks = len(scheduler.kv_cache_manager.single_type_manager. + req_to_blocks[request_id]) + assert num_blocks == NUM_EXTERNAL_FULL_BLOCKS + model_runner_output = EMPTY_MODEL_RUNNER_OUTPUT + scheduler.update_from_output(scheduler_output, model_runner_output) + + # # STEP (2): Recv. + scheduler_output = scheduler.schedule() + model_runner_output = copy.deepcopy(EMPTY_MODEL_RUNNER_OUTPUT) + model_runner_output.finished_recving = [request_id] + scheduler.update_from_output(scheduler_output, model_runner_output) + assert len(scheduler.waiting) == 1 + assert (request_id in scheduler.finished_recving_kv_req_ids) + + # # STEP (3): Run as usual. + scheduler_output = scheduler.schedule() + + # We need to recompute the final token of the prompt to generate + # the first new token, so we should not have a new block. + num_blocks = len(scheduler.kv_cache_manager.single_type_manager. + req_to_blocks[request_id]) + assert num_blocks == NUM_EXTERNAL_FULL_BLOCKS + assert (scheduler_output.scheduled_new_reqs[0].num_computed_tokens == + NUM_TOKENS - 1) + assert (scheduler_output.num_scheduled_tokens[request_id] == 1) + + model_runner_output = create_model_runner_output([request]) + scheduler.update_from_output(scheduler_output, model_runner_output) + + # # Step (4): Hit EOS. + scheduler_output = scheduler.schedule() + model_runner_output = create_model_runner_output([request], use_eos=True) + engine_core_outputs = scheduler.update_from_output(scheduler_output, + model_runner_output) + scheduler.schedule() + + outputs = engine_core_outputs.outputs + assert len(outputs) == 1 + output = outputs[0] + assert output.finish_reason == FinishReason.STOP + assert_scheduler_empty(scheduler) diff --git a/tests/v1/kv_connector/unit/utils.py b/tests/v1/kv_connector/unit/utils.py new file mode 100644 index 000000000000..8a7d7bdd83da --- /dev/null +++ b/tests/v1/kv_connector/unit/utils.py @@ -0,0 +1,190 @@ +# SPDX-License-Identifier: Apache-2.0 +from typing import Optional + +import torch + +from vllm import SamplingParams +from vllm.config import (CacheConfig, DeviceConfig, KVTransferConfig, + ModelConfig, SchedulerConfig, VllmConfig) +from vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector import ( + NixlKVTransferParams) +from vllm.v1.core.sched.scheduler import Scheduler +from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig, + KVCacheGroupSpec) +from vllm.v1.outputs import ModelRunnerOutput +from vllm.v1.request import Request +from vllm.v1.structured_output import StructuredOutputManager + +EOS_TOKEN_ID = 50256 + + +def assert_scheduler_empty(scheduler: Scheduler): + """Confirm the scheduler is "empty" - i.e. no leaks.""" + # Scheduler Metadata. + assert len(scheduler.requests) == 0 + assert len(scheduler.waiting) == 0 + assert len(scheduler.running) == 0 + assert len(scheduler.finished_req_ids) == 0 + assert len(scheduler.finished_recving_kv_req_ids) == 0 + assert len(scheduler._cached_reqs_data) == 0 + + # EncoderCacheManager. + assert len(scheduler.encoder_cache_manager.freed) == 0 + assert len(scheduler.encoder_cache_manager.cached) == 0 + + # KVCache Manager. + assert len( + scheduler.kv_cache_manager.single_type_manager.req_to_blocks) == 0 + assert len(scheduler.kv_cache_manager.req_to_block_hashes) == 0 + assert len( + scheduler.kv_cache_manager.single_type_manager.num_cached_block) == 0 + num_free_blocks = ( + scheduler.kv_cache_manager.block_pool.free_block_queue.num_free_blocks) + assert num_free_blocks == ( + scheduler.kv_cache_manager.block_pool.num_gpu_blocks - 1) + + # NOTE(rob): just the ref count on blocks will be 0. The hash + # value, etc will remain since we lazily evict for prefix cache. + for block in scheduler.kv_cache_manager.block_pool.blocks: + assert block.ref_cnt == 0 + + +def create_vllm_config( + model: str = "facebook/opt-125m", + max_num_seqs: int = 16, + max_num_batched_tokens: int = 64, + block_size: int = 16, +) -> VllmConfig: + """Initialize VllmConfig For Testing.""" + scheduler_config = SchedulerConfig( + max_num_seqs=max_num_seqs, + max_num_batched_tokens=max_num_batched_tokens, + max_model_len=max_num_batched_tokens, + ) + model_config = ModelConfig( + model=model, + task="auto", + tokenizer=model, + tokenizer_mode="auto", + trust_remote_code=True, + dtype="float16", + seed=42, + ) + # Cache config, optionally force APC + cache_config = CacheConfig( + block_size=block_size, + gpu_memory_utilization=0.9, + swap_space=0, + cache_dtype="auto", + enable_prefix_caching=True, + ) + kv_transfer_config = KVTransferConfig( + kv_connector="NixlConnector", + kv_role="kv_both", + ) + return VllmConfig(scheduler_config=scheduler_config, + model_config=model_config, + cache_config=cache_config, + kv_transfer_config=kv_transfer_config, + device_config=DeviceConfig("cpu")) + + +def create_scheduler( + vllm_config: VllmConfig, + num_blocks: int = 10000, +) -> Scheduler: + """Initialize Scheduler For Testing.""" + block_size = vllm_config.cache_config.block_size + kv_cache_config = KVCacheConfig( + num_blocks=num_blocks, # A large number of blocks to hold all requests + tensors={}, + kv_cache_groups=[ + KVCacheGroupSpec(['layer'], + FullAttentionSpec(block_size, 1, 1, torch.float32, + False)) + ], + ) + vllm_config.cache_config.num_gpu_blocks = num_blocks + return Scheduler( + vllm_config=vllm_config, + kv_cache_config=kv_cache_config, + log_stats=True, + structured_output_manager=StructuredOutputManager(vllm_config), + ) + + +def create_request( + request_id: int, + num_tokens: int = 10, + max_tokens: int = 16, + do_remote_decode: bool = False, + do_remote_prefill: bool = False, + use_all_1s_for_prompt_tokens: bool = False, + num_remote_blocks: int = 3, +) -> Request: + """Make dummy request for testing.""" + + if do_remote_decode: + assert not do_remote_prefill + kv_transfer_params = NixlKVTransferParams(do_remote_prefill=False, + do_remote_decode=True) + elif do_remote_prefill: + kv_transfer_params = NixlKVTransferParams( + do_remote_prefill=True, + do_remote_decode=False, + remote_engine_id="my-engine-id", + remote_block_ids=list(range(num_remote_blocks)), + remote_host="my-host", + remote_port=1234) + else: + kv_transfer_params = None + + max_tokens = 1 if do_remote_decode else max_tokens + sampling_params = SamplingParams(max_tokens=max_tokens) + + if use_all_1s_for_prompt_tokens: + prompt_token_ids = [1] * num_tokens + else: + prompt_token_ids = [i * request_id for i in range(num_tokens)] + + req = Request( + request_id=f"id-{request_id}", + prompt_token_ids=prompt_token_ids, + sampling_params=sampling_params, + multi_modal_inputs=None, + multi_modal_placeholders=None, + multi_modal_hashes=None, + eos_token_id=EOS_TOKEN_ID, + arrival_time=0, + ) + req.kv_transfer_params = kv_transfer_params + return req + + +def create_model_runner_output( + reqs: list[Request], + finished_sending: Optional[list[str]] = None, + finished_recving: Optional[list[str]] = None, + use_eos: bool = False, +) -> ModelRunnerOutput: + """Make dummy model runner output for testing.""" + + # Make request data. + req_ids = [req.request_id for req in reqs] + req_id_to_index = {req_id: idx for idx, req_id in enumerate(req_ids)} + + # Make sampled tokens. + sampled_token = EOS_TOKEN_ID if use_eos else 0 + sampled_token_ids = [[sampled_token] for _ in req_ids] + + # Make output data structure. + return ModelRunnerOutput( + req_ids=req_ids, + req_id_to_index=req_id_to_index, + sampled_token_ids=sampled_token_ids, + spec_token_ids=None, + logprobs=None, + prompt_logprobs_dict={}, + finished_sending=finished_sending, + finished_recving=finished_recving, + ) diff --git a/vllm/config.py b/vllm/config.py index 4a503665503a..c6b97bbdcd66 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -8,6 +8,7 @@ import json import re import textwrap +import uuid import warnings from collections import Counter from contextlib import contextmanager @@ -3438,6 +3439,9 @@ class KVTransferConfig: """The KV connector for vLLM to transmit KV caches between vLLM instances. """ + engine_id: str = str(uuid.uuid4()) + """The engine id for KV transfers.""" + kv_buffer_device: Optional[str] = "cuda" """The device used by kv connector to buffer the KV cache. Currently only support 'cuda'.""" @@ -3448,7 +3452,7 @@ class KVTransferConfig: kv_role: Optional[KVRole] = None """Whether this vLLM instance produces, consumes KV cache, or both. Choices - are 'kv_producer', 'kv_consumer', and 'both'.""" + are 'kv_producer', 'kv_consumer', and 'kv_both'.""" kv_rank: Optional[int] = None """The rank of this vLLM instance in the KV cache transfer. Typical value: diff --git a/vllm/distributed/kv_transfer/kv_connector/factory.py b/vllm/distributed/kv_transfer/kv_connector/factory.py index 6532c101a4f6..54cb1871db3c 100644 --- a/vllm/distributed/kv_transfer/kv_connector/factory.py +++ b/vllm/distributed/kv_transfer/kv_connector/factory.py @@ -105,3 +105,8 @@ def create_connector_v1( "LMCacheConnectorV1", "vllm.distributed.kv_transfer.kv_connector.v1.lmcache_connector", "LMCacheConnectorV1") + +KVConnectorFactory.register_connector( + "NixlConnector", + "vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector", + "NixlConnector") diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/__init__.py b/vllm/distributed/kv_transfer/kv_connector/v1/__init__.py index a017b140e090..43181ab79afc 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/__init__.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/__init__.py @@ -1,8 +1,5 @@ # SPDX-License-Identifier: Apache-2.0 from vllm.distributed.kv_transfer.kv_connector.v1.base import ( - KVConnectorBase_V1, KVConnectorRole) + KVConnectorBase_V1, KVConnectorRole, KVTransferParams) -__all__ = [ - "KVConnectorRole", - "KVConnectorBase_V1", -] +__all__ = ["KVConnectorRole", "KVConnectorBase_V1", "KVTransferParams"] diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/base.py b/vllm/distributed/kv_transfer/kv_connector/v1/base.py index 95967d2ca919..2ff61e8a400f 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/base.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/base.py @@ -23,7 +23,7 @@ import enum from abc import ABC, abstractmethod from dataclasses import dataclass -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Any, Optional import torch @@ -34,6 +34,7 @@ from vllm.attention.backends.abstract import AttentionMetadata from vllm.config import VllmConfig from vllm.forward_context import ForwardContext + from vllm.v1.core.kv_cache_manager import KVCacheBlocks from vllm.v1.request import Request logger = init_logger(__name__) @@ -47,12 +48,34 @@ class KVConnectorRole(enum.Enum): WORKER = 1 +class KVTransferParams: + """ + Abstract KVTransferParams used to send KVTransfer + parameters between instances of vLLM. + + Specific instances of KVConnector customize this + method for serializing / deserializing msgs sent + via the HTTP protocol. + """ + + @staticmethod + def from_raw_dict( + raw_dict: Optional[dict[str, + Any]]) -> Optional["KVTransferParams"]: + return None + + @dataclass class KVConnectorMetadata: + """ + Abstract Metadata used to communicate between the + Scheduler KVConnector and Worker KVConnector. + """ pass class KVConnectorBase_V1(ABC): + _KVTransferParams = KVTransferParams def __init__(self, vllm_config: "VllmConfig", role: KVConnectorRole): logger.warning( @@ -66,6 +89,10 @@ def __init__(self, vllm_config: "VllmConfig", role: KVConnectorRole): def role(self) -> KVConnectorRole: return self._role + # ============================== + # Worker-side methods + # ============================== + def bind_connector_metadata( self, connector_metadata: KVConnectorMetadata) -> None: """Set the connector metadata from the scheduler. @@ -97,9 +124,15 @@ def _get_connector_metadata(self) -> KVConnectorMetadata: """ return self._connector_metadata - # ============================== - # Worker-side methods - # ============================== + def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]): + """ + Initialize with the KV caches. Useful for pre-registering the + KV Caches in the KVConnector (e.g. for NIXL). + + Args: kv_caches: + dictionary of layer names, kv cache + """ + return @abstractmethod def start_load_kv(self, forward_context: "ForwardContext", @@ -162,15 +195,37 @@ def wait_for_save(self): """ pass + def get_finished( + self, finished_req_ids: set[str] + ) -> tuple[Optional[set[str]], Optional[set[str]]]: + """ + Notifies worker-side connector ids of requests that have + finished generating tokens. + + Returns: + ids of requests that have finished asynchronous (recving, sending). + The finished saves/sends req ids must belong to a set provided in a + call to this method (this call or a prior one). + """ + return None, None + # ============================== # Scheduler-side methods # ============================== + + def set_kv_transfer_params(self, request: "Request"): + """Parse raw KV Transfer params.""" + assert request.kv_transfer_params is None + kv_transfer_params = self._KVTransferParams.from_raw_dict( + request.raw_kv_transfer_params) + request.kv_transfer_params = kv_transfer_params + @abstractmethod def get_num_new_matched_tokens( self, request: "Request", num_computed_tokens: int, - ) -> int: + ) -> tuple[int, bool]: """ Get number of new tokens that can be loaded from the external KV cache beyond the num_computed_tokens. @@ -181,13 +236,16 @@ def get_num_new_matched_tokens( computed tokens for this request Returns: - the number of tokens that can be loaded from the - external KV cache beyond what is already computed. + * the number of tokens that can be loaded from the + external KV cache beyond what is already computed. + * true if external KV cache tokens will be loaded + asynchronously (between scheduler steps). """ pass @abstractmethod def update_state_after_alloc(self, request: "Request", + blocks: "KVCacheBlocks", num_external_tokens: int): """ Update KVConnector state after block allocation. @@ -207,3 +265,20 @@ def build_connector_meta( scheduler_output (SchedulerOutput): the scheduler output object. """ pass + + def request_finished( + self, + request: "Request", + block_ids: list[int], + ) -> tuple[bool, Optional[dict[str, Any]]]: + """ + Called when a request has finished, before its blocks are freed. + + Returns: + True if the request is being saved/sent asynchronously and blocks + should not be freed until the request_id is returned from + get_finished(). + Optional KVTransferParams to be included in the request outputs + returned by the engine. + """ + return False, None diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/lmcache_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/lmcache_connector.py index e07f185f0dd8..2cb68dc1ff67 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/lmcache_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/lmcache_connector.py @@ -13,6 +13,7 @@ if TYPE_CHECKING: from vllm.attention.backends.abstract import AttentionMetadata from vllm.forward_context import ForwardContext + from vllm.v1.core.kv_cache_manager import KVCacheBlocks from vllm.v1.request import Request logger = init_logger(__name__) @@ -92,7 +93,7 @@ def get_num_new_matched_tokens( self, request: "Request", num_computed_tokens: int, - ) -> int: + ) -> tuple[int, bool]: """ Get number of new tokens that can be loaded from the external KV cache beyond the num_computed_tokens. @@ -107,9 +108,10 @@ def get_num_new_matched_tokens( external KV cache beyond what is already computed. """ return self._lmcache_engine.get_num_new_matched_tokens( - request, num_computed_tokens) + request, num_computed_tokens), False def update_state_after_alloc(self, request: "Request", + blocks: "KVCacheBlocks", num_external_tokens: int): """ Update KVConnector state after block allocation. diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py new file mode 100644 index 000000000000..d26184982270 --- /dev/null +++ b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py @@ -0,0 +1,805 @@ +# SPDX-License-Identifier: Apache-2.0 +import contextlib +import math +import threading +import time +import uuid +from collections import defaultdict +from dataclasses import dataclass +from typing import TYPE_CHECKING, Any, Iterator + +import msgspec +import torch +import zmq +from typing_extensions import Optional + +from vllm import envs +from vllm.config import VllmConfig +from vllm.distributed.kv_transfer.kv_connector.v1.base import ( + KVConnectorBase_V1, KVConnectorMetadata, KVConnectorRole, KVTransferParams) +from vllm.distributed.parallel_state import ( + get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size, + get_tp_group) +from vllm.logger import init_logger +from vllm.utils import round_down +from vllm.v1.core.sched.output import SchedulerOutput +from vllm.v1.request import RequestStatus + +if TYPE_CHECKING: + from vllm.attention.backends.abstract import AttentionMetadata + from vllm.forward_context import ForwardContext + from vllm.v1.core.kv_cache_manager import KVCacheBlocks + from vllm.v1.request import Request + +GET_META_MSG = b"get_meta_msg" + +logger = init_logger(__name__) + +# Lazy import nixl_wrapper to avoid loading nixl_bindings if nixl is not used +try: + from nixl._api import nixl_agent as NixlWrapper + logger.info("NIXL is available") +except ImportError: + logger.warning("NIXL is not available") + NixlWrapper = None + + +@dataclass +class NixlKVTransferParams(KVTransferParams): + + def __init__( + self, + do_remote_prefill: bool, + do_remote_decode: bool, + remote_block_ids: Optional[list[int]] = None, + remote_host: Optional[str] = None, + remote_port: Optional[int] = None, + remote_engine_id: Optional[str] = None, + ): + self.do_remote_prefill = do_remote_prefill + self.do_remote_decode = do_remote_decode + self.remote_block_ids = remote_block_ids + self.remote_host = remote_host + self.remote_port = remote_port + self.remote_engine_id = remote_engine_id + + @staticmethod + def from_raw_dict( + raw_dict: Optional[dict[str, + Any]]) -> Optional["NixlKVTransferParams"]: + + # If no raw transfer params passed, return None. + if raw_dict is None: + return None + + # Validate the request is formatted properly. + if (("do_remote_prefill" not in raw_dict) + or ("do_remote_decode" not in raw_dict) + or ("remote_block_ids" not in raw_dict) + or ("remote_host" not in raw_dict) + or ("remote_port" not in raw_dict) + or ("remote_engine_id" not in raw_dict)): + logger.warning( + "Got invalid KVTransferParams: %s. This " + "request will not utilize KVTransfer", raw_dict) + return None + + return NixlKVTransferParams( + do_remote_prefill=raw_dict["do_remote_prefill"], + do_remote_decode=raw_dict["do_remote_decode"], + remote_block_ids=raw_dict["remote_block_ids"], + remote_host=raw_dict["remote_host"], + remote_port=raw_dict["remote_port"], + remote_engine_id=raw_dict["remote_engine_id"], + ) + + +class NixlAgentMetadata( + msgspec.Struct, + omit_defaults=True, # type: ignore[call-arg] + # required for @cached_property. + dict=True): + engine_id: str + agent_metadata: bytes + kv_caches_base_addr: list[int] + num_blocks: int + + +@dataclass +class ReqMeta: + local_block_ids: list[int] + remote_block_ids: list[int] + remote_host: str + remote_port: int + remote_engine_id: str + + +class NixlConnectorMetadata(KVConnectorMetadata): + + def __init__(self): + self.requests: dict[str, ReqMeta] = {} + + def add_new_req( + self, + request_id: str, + local_block_ids: list[int], + kv_transfer_params: NixlKVTransferParams, + ): + assert request_id not in self.requests + assert kv_transfer_params.remote_block_ids is not None + assert kv_transfer_params.remote_engine_id is not None + assert kv_transfer_params.remote_host is not None + assert kv_transfer_params.remote_port is not None + + self.requests[request_id] = ReqMeta( + local_block_ids=local_block_ids, + remote_block_ids=kv_transfer_params.remote_block_ids, + remote_engine_id=kv_transfer_params.remote_engine_id, + remote_host=kv_transfer_params.remote_host, + remote_port=kv_transfer_params.remote_port, + ) + + +class NixlConnector(KVConnectorBase_V1): + _KVTransferParams: type[NixlKVTransferParams] = NixlKVTransferParams + + def __init__(self, vllm_config: VllmConfig, role: KVConnectorRole): + assert vllm_config.kv_transfer_config is not None + self.engine_id = vllm_config.kv_transfer_config.engine_id + + if role == KVConnectorRole.SCHEDULER: + self.connector_scheduler : Optional[NixlConnectorScheduler] = \ + NixlConnectorScheduler(vllm_config, str(self.engine_id)) + self.connector_worker: Optional[NixlConnectorWorker] = None + elif role == KVConnectorRole.WORKER: + self.connector_scheduler = None + self.connector_worker = NixlConnectorWorker(str(self.engine_id)) + + ############################################################ + # Scheduler Side Methods + ############################################################ + + def get_num_new_matched_tokens( + self, request: "Request", + num_computed_tokens: int) -> tuple[int, bool]: + assert self.connector_scheduler is not None + return self.connector_scheduler.get_num_new_matched_tokens( + request, num_computed_tokens) + + def update_state_after_alloc(self, request: "Request", + blocks: "KVCacheBlocks", + num_external_tokens: int): + assert self.connector_scheduler is not None + return self.connector_scheduler.update_state_after_alloc( + request, blocks, num_external_tokens) + + def build_connector_meta( + self, + scheduler_output: SchedulerOutput, + ) -> KVConnectorMetadata: + assert self.connector_scheduler is not None + return self.connector_scheduler.build_connector_meta(scheduler_output) + + def request_finished( + self, + request: "Request", + block_ids: list[int], + ) -> tuple[bool, Optional[dict[str, Any]]]: + assert self.connector_scheduler is not None + return self.connector_scheduler.request_finished(request, block_ids) + + ############################################################ + # Worker Side Methods + ############################################################ + def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]): + assert self.connector_worker is not None + self.connector_worker.register_kv_caches(kv_caches) + + def get_finished(self, + finished_req_ids: set[str]) -> tuple[set[str], set[str]]: + """Get the finished recving and sending requests.""" + assert self.connector_worker is not None + return self.connector_worker.get_finished() + + def start_load_kv(self, forward_context: "ForwardContext", + **kwargs) -> None: + assert self.connector_worker is not None + assert isinstance(self._connector_metadata, NixlConnectorMetadata) + self.connector_worker.start_load_kv(self._connector_metadata) + + def wait_for_layer_load(self, layer_name: str) -> None: + """NixlConnector does not do layerwise saving.""" + pass + + def save_kv_layer(self, layer_name: str, kv_layer: torch.Tensor, + attn_metadata: "AttentionMetadata", **kwargs) -> None: + """NixlConnector does not save explicitly.""" + pass + + def wait_for_save(self): + """NixlConnector does not save explicitly.""" + pass + + +class NixlConnectorScheduler: + """Implementation of Scheduler side methods""" + + def __init__(self, vllm_config: VllmConfig, engine_id: str): + self.vllm_config = vllm_config + self.block_size = vllm_config.cache_config.block_size + self.engine_id = engine_id + logger.info("Initializing NIXL Scheduler %s", engine_id) + + # Requests that need to start recv. + # New requests are added by update_state_after_alloc in + # the scheduler. Used to make metadata passed to Worker. + self._reqs_need_recv: dict[str, tuple[Request, list[int]]] = {} + + def get_num_new_matched_tokens( + self, request: "Request", + num_computed_tokens: int) -> tuple[int, bool]: + """ + For remote prefill, pull all prompt blocks from remote + asynchronously relative to engine execution. + + Args: + request (Request): the request object. + num_computed_tokens (int): the number of locally + computed tokens for this request + Returns: + * the number of tokens that can be loaded from the + external KV cache beyond what is already computed. + * true if the external KV cache tokens will be loaded + asynchronously (between scheduler steps). + """ + + # No KVTransfer for this request. + if request.kv_transfer_params is None: + return 0, False + assert isinstance(request.kv_transfer_params, NixlKVTransferParams) + + # Remote prefill: get all prompt blocks from remote. + if request.kv_transfer_params.do_remote_prefill: + assert num_computed_tokens % self.block_size == 0 + rounded_num_prompt_tokens = round_down( + len(request.prompt_token_ids), self.block_size) + count = max(rounded_num_prompt_tokens - num_computed_tokens, 0) + return count, count > 0 + + return 0, False + + def update_state_after_alloc(self, request: "Request", + blocks: "KVCacheBlocks", + num_external_tokens: int): + if request.kv_transfer_params is None: + return + + assert isinstance(request.kv_transfer_params, NixlKVTransferParams) + if request.kv_transfer_params.do_remote_prefill: + # NOTE(rob): if prompt < block_size, no remote blocks + # since the remote only sends fully computed blocks, so + # skip recving for this request. num_external_tokens + # should be 0 if there are no remote blocks. + if request.kv_transfer_params.remote_block_ids: + # Get unhashed blocks to pull from remote. + self._reqs_need_recv[request.request_id] = ( + request, blocks.get_unhashed_block_ids()) + else: + assert num_external_tokens == 0 + # Only trigger 1 KV transfer per request. + request.kv_transfer_params.do_remote_prefill = False + + def build_connector_meta( + self, + scheduler_output: SchedulerOutput, + ) -> KVConnectorMetadata: + meta = NixlConnectorMetadata() + + # Loop through scheduled reqs and convert to ReqMeta. + for req_id, (req, block_ids) in self._reqs_need_recv.items(): + assert isinstance(req.kv_transfer_params, NixlKVTransferParams) + meta.add_new_req( + request_id=req_id, + local_block_ids=block_ids, + kv_transfer_params=req.kv_transfer_params, + ) + + # Clear the list once workers start the transfers + self._reqs_need_recv.clear() + + return meta + + def request_finished( + self, + request: "Request", + block_ids: list[int], + ) -> tuple[bool, Optional[dict[str, Any]]]: + """ + Once a request is finished, determine whether request blocks + should be freed now or will be sent asynchronously and freed later. + """ + + if request.kv_transfer_params is None: + return False, None + assert isinstance(request.kv_transfer_params, NixlKVTransferParams) + + if ((not request.kv_transfer_params.do_remote_decode) + or (request.status != RequestStatus.FINISHED_LENGTH_CAPPED)): + return False, None + + # Get computed blocks. + all_full = request.num_computed_tokens % self.block_size == 0 + computed_block_ids = (block_ids if all_full else block_ids[:-1]) + + # If prompt < block_size, no xfer so free blocks immediately. + delay_free_blocks = len(computed_block_ids) > 0 + + return delay_free_blocks, NixlKVTransferParams( + do_remote_prefill=True, + do_remote_decode=False, + remote_block_ids=computed_block_ids, + remote_engine_id=self.engine_id, + remote_host=envs.VLLM_NIXL_SIDE_CHANNEL_HOST, + remote_port=envs.VLLM_NIXL_SIDE_CHANNEL_PORT, + ).__dict__ + + +class NixlConnectorWorker: + """Implementation of Worker side methods""" + + def __init__(self, engine_id: str): + if NixlWrapper is None: + logger.error("NIXL is not available") + raise RuntimeError("NIXL is not available") + logger.info("Initializing NIXL wrapper") + logger.info("Initializing NIXL worker %s", engine_id) + + # Agent. + self.nixl_wrapper = NixlWrapper(str(uuid.uuid4()), None) + # Map of engine_id -> agent_name. + self._remote_agents: dict[str, str] = {} + + # Metadata. + self.engine_id = engine_id + self.rank = get_tensor_model_parallel_rank() + self.world_size = get_tensor_model_parallel_world_size() + self.tp_group = get_tp_group() + + # KV Caches and nixl tracking data. + self.kv_caches: dict[str, torch.Tensor] = {} + + # Map of engine_id -> kv_caches_base_addr + self.kv_caches_base_addr: dict[str, list[int]] = {} + + # Number of NIXL regions. Currently one region per cache + # (so 1 per layer for MLA, otherwise 2 per layer) + self.num_regions = 0 + + # nixl_prepped_dlist_handle (int). + self.src_xfer_side_handle: int = 0 + # Map of engine_id -> nixl_prepped_dlist_handle (int)]. + self.dst_xfer_side_handles: dict[str, int] = {} + + # Map of engine_id -> num_blocks. + self.dst_num_blocks: dict[str, int] = {} + self._registered_descs: list[Any] = [] + + # In progress transfers. + # [req_id -> list[handle]] + self._recving_transfers: defaultdict[str, list[Any]] = defaultdict( + list[Any]) + + # Complete transfer tracker. Used by the rank 0 to track finished + # transactions on ranks 1 to N-1. + # [req_id -> count] + self._done_recving_count: defaultdict[str, + int] = defaultdict(lambda: 0) + self._done_sending_count: defaultdict[str, + int] = defaultdict(lambda: 0) + + # Background thread for establishing new connections. + self._nixl_handshake_listener_t: Optional[threading.Thread] = None + + @staticmethod + def _nixl_handshake_listener(metadata: NixlAgentMetadata, + ready_event: threading.Event, rank: int): + """Background thread for getting new NIXL handshakes.""" + # NOTE(rob): this is a simple implementation. We will move + # to a better approach like an ETCD server in the future. + + # NOTE(rob): to support heterogeneous TP, we will have to + # move this into the scheduler rather than worker, since + # each rank needs the metadata of all other ranks (whereas + # in this setup, each rank only gets one other rank's meta. + + encoder = msgspec.msgpack.Encoder() + encoded_data = encoder.encode(metadata) + size_in_bytes = len(encoded_data) + logger.debug("Size of encoded NixlAgentMetadata: %s bytes", + str(size_in_bytes)) + + # Listen for new requests for metadata. + host = envs.VLLM_NIXL_SIDE_CHANNEL_HOST + # NOTE(rob): we need each rank to have a unique port. This + # hack to keeps us moving. We will switch when moving to etcd + # or where we have a single ZMQ socket in the scheduler. + port = envs.VLLM_NIXL_SIDE_CHANNEL_PORT + rank + path = f"tcp://{host}:{port}" + logger.debug("Starting listening on path: %s", path) + with zmq_ctx(zmq.ROUTER, path) as sock: + ready_event.set() + while True: + identity, _, msg = sock.recv_multipart() + if msg != GET_META_MSG: + logger.warning( + "Connection listener got unexpected message %s", msg) + sock.send_multipart((identity, b"", encoded_data)) + + def _nixl_handshake(self, host: str, port: int): + """Do a NIXL handshake with a remote instance.""" + + start_time = time.perf_counter() + # NOTE(rob): we need each rank to have a unique port. This is + # a hack to keep us moving. We will switch when moving to etcd + # or where we have a single ZMQ socket in the scheduler. + path = f"tcp://{host}:{port + self.rank}" + logger.debug("Querying metadata on path: %s", path) + with zmq_ctx(zmq.REQ, path) as sock: + # Send query for the request. + sock.send(GET_META_MSG) + metadata_bytes = sock.recv() + decoder = msgspec.msgpack.Decoder(NixlAgentMetadata) + metadata = decoder.decode(metadata_bytes) + got_metadata_time = time.perf_counter() + + # Register Remote agent. + self.add_remote_agent(metadata) + setup_agent_time = time.perf_counter() + + logger.debug("NIXL handshake: get metadata took: %s", + got_metadata_time - start_time) + logger.debug("NIXL handshake: add agent took: %s", + setup_agent_time - got_metadata_time) + + def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]): + """Register the KV Cache data in nixl.""" + + _, first_kv_cache = next(iter(kv_caches.items())) + kv_elem_size = first_kv_cache.element_size() + + # TODO(tms): Find a more robust way to detect and handle MLA + use_mla = len(first_kv_cache.shape) == 3 + if use_mla: + # MLA case. + self.num_blocks = first_kv_cache.shape[0] + block_rank = 2 # [block_size, latent_dim] + block_shape = first_kv_cache.shape[-block_rank:] + else: + # [2 (k and v), num_blocks, ...] + self.num_blocks = first_kv_cache.shape[1] + block_rank = 3 # [block_size, kv_heads, head_dim] + block_shape = first_kv_cache.shape[-block_rank:] + + # TODO(tms): self.block_len needs to be per-layer for sliding window, + # hybrid attn, etc + self.block_len = kv_elem_size * math.prod(block_shape) + + logger.debug("Registering KV_Caches. use_mla: %s, shape %s", use_mla, + first_kv_cache.shape) + logger.debug("num_blocks: %s, block_shape: %s", self.num_blocks, + block_shape) + logger.debug("Per layer kv cache size: %s", first_kv_cache.shape) + self.dst_num_blocks[self.engine_id] = self.num_blocks + self.kv_caches = kv_caches + kv_caches_base_addr = [] + caches_data = [] + + # Note(tms): I modified this from the original region setup code. + # K and V are now in different regions. Advantage is that we can + # elegantly support MLA and any cases where the K and V tensors + # are non-contiguous (it's not locally guaranteed that they will be) + # Disadvantage is that the encoded NixlAgentMetadata is now larger + # (roughly 8KB vs 5KB). + for cache_or_caches in kv_caches.values(): + # Normalize to always be a list of caches + cache_list = [cache_or_caches] if use_mla else cache_or_caches + for cache in cache_list: + base_addr = cache.data_ptr() + region_len = self.num_blocks * self.block_len + caches_data.append((base_addr, region_len, self.rank, "")) + kv_caches_base_addr.append(base_addr) + self.kv_caches_base_addr[self.engine_id] = kv_caches_base_addr + self.num_regions = len(caches_data) + + descs = self.nixl_wrapper.get_reg_descs(caches_data, "VRAM") + logger.debug("Registering descs: %s", caches_data) + self.nixl_wrapper.register_memory(descs) + logger.debug("Done registering descs") + + self._registered_descs.append(descs) + + # After KV Caches registered, listen for new connections. + metadata = NixlAgentMetadata( + engine_id=self.engine_id, + agent_metadata=self.nixl_wrapper.get_agent_metadata(), + kv_caches_base_addr=self.kv_caches_base_addr[self.engine_id], + num_blocks=self.num_blocks, + ) + ready_event = threading.Event() + self._nixl_handshake_listener_t = threading.Thread( + target=self._nixl_handshake_listener, + args=(metadata, ready_event, self.rank), + daemon=True, + name="nixl_handshake_listener") + self._nixl_handshake_listener_t.start() + ready_event.wait() + + def add_remote_agent(self, nixl_agent_meta: NixlAgentMetadata): + engine_id = nixl_agent_meta.engine_id + if engine_id in self._remote_agents: + return + + self._remote_agents[engine_id] = self.nixl_wrapper.add_remote_agent( + nixl_agent_meta.agent_metadata) + self.kv_caches_base_addr[ + engine_id] = nixl_agent_meta.kv_caches_base_addr + + # Create src descs and xfer side handles. + blocks_data = [] + for base_addr in self.kv_caches_base_addr[self.engine_id]: + for block_id in range(self.num_blocks): + block_offset = block_id * self.block_len + # (addr, len, device id) + blocks_data.append( + (base_addr + block_offset, self.block_len, self.rank)) + logger.debug("Created %s blocks for src engine %s and rank %s", + len(blocks_data), self.engine_id, self.rank) + + # Register with NIXL. + descs = self.nixl_wrapper.get_xfer_descs(blocks_data, "VRAM") + self.src_xfer_side_handle = self.nixl_wrapper.prep_xfer_dlist( + "NIXL_INIT_AGENT", descs) + + # Create dst descs and xfer side handles. + self.dst_num_blocks[engine_id] = nixl_agent_meta.num_blocks + blocks_data = [] + for base_addr in self.kv_caches_base_addr[engine_id]: + for block_id in range(nixl_agent_meta.num_blocks): + block_offset = block_id * self.block_len + # (addr, len, device id) + blocks_data.append( + (base_addr + block_offset, self.block_len, self.rank)) + logger.debug("Created %s blocks for dst engine %s and rank %s", + len(blocks_data), engine_id, self.rank) + + # Register with NIXL. + descs = self.nixl_wrapper.get_xfer_descs(blocks_data, "VRAM") + self.dst_xfer_side_handles[ + engine_id] = self.nixl_wrapper.prep_xfer_dlist( + self._remote_agents[engine_id], descs) + + def get_finished(self) -> tuple[set[str], set[str]]: + """ + Get requests that are done sending or recving. + + In TP>1 setup, each rank exchanges KVs with its counterpart + ranks independently. get_finished() runs in a worker creates + the done_sending and done_recving sets that are sent to the + scheduler via ModelRunnerOutput by Rank 0. To ensure trnxs + are done before adding to finished, Ranks 1 to N-1 communicate + to Rank 0 once their transaction is done + Rank 0 returns + finished sets to Scheduler only once all ranks are done. + """ + done_sending = self._get_new_notifs() + done_recving = self._pop_done_transfers(self._recving_transfers) + if len(done_sending) > 0 or len(done_recving) > 0: + logger.debug( + "Rank %s, get_finished: %s requests done sending " + "and %s requests done recving", self.rank, len(done_sending), + len(done_recving)) + + if self.world_size == 1: + return done_sending, done_recving + + # Rank 0: get finished from all other ranks. + if self.rank == 0: + for req_id in done_sending: + self._done_sending_count[req_id] += 1 + for req_id in done_recving: + self._done_recving_count[req_id] += 1 + + # Keep track of how many other ranks have finished. + other_ranks_finished_ids: list[str] = [] + for i in range(1, self.world_size): + other_ranks_finished_ids.extend( + self.tp_group.recv_object(src=i)) + for req_id in other_ranks_finished_ids: + if (req_id in self._done_recving_count + or req_id in self._recving_transfers): + self._done_recving_count[req_id] += 1 + else: + self._done_sending_count[req_id] += 1 + + # Return ids that finished on all ranks to the scheduler. + all_done_recving: set[str] = set() + for req_id in list(self._done_recving_count.keys()): + if self._done_recving_count[req_id] == self.world_size: + del self._done_recving_count[req_id] + all_done_recving.add(req_id) + + all_done_sending: set[str] = set() + for req_id in list(self._done_sending_count.keys()): + if self._done_sending_count[req_id] == self.world_size: + del self._done_sending_count[req_id] + all_done_sending.add(req_id) + + return all_done_sending, all_done_recving + + # Ranks 1 to N-1: send finished ids to Rank 0. + else: + finished_req_ids = list(done_recving.union(done_sending)) + self.tp_group.send_object(finished_req_ids, dst=0) + + # Unused as only Rank 0 results are sent to scheduler. + return done_sending, done_recving + + def _get_new_notifs(self) -> set[str]: + """Get req_ids which got a remote xfer message.""" + + notified_req_ids: set[str] = set() + for req_ids in self.nixl_wrapper.get_new_notifs().values(): + for req_id in req_ids: + assert req_id not in notified_req_ids + notified_req_ids.add(req_id.decode("utf-8")) + return notified_req_ids + + def _pop_done_transfers(self, transfers: dict[str, list[int]]) -> set[str]: + """ + Pop completed xfers by checking for DONE state. + Args: + transfers: dict of req_id -> list[running_xfer] + Returns: + set of req_ids that have all done xfers + """ + done_req_ids: set[str] = set() + for req_id, handles in list(transfers.items()): + running_reqs = [] + for handle in handles: + xfer_state = self.nixl_wrapper.check_xfer_state(handle) + if xfer_state == "DONE": + # TODO ptarasiewicz: why abort is throwing errors? + # self.nixl_wrapper.release_xfer_handle(handle) + continue + if xfer_state == "PROC": + running_reqs.append(handle) + else: + raise RuntimeError("Transfer failed with state %s", + xfer_state) + if len(running_reqs) == 0: + done_req_ids.add(req_id) + del transfers[req_id] + else: + transfers[req_id] = running_reqs + return done_req_ids + + def start_load_kv(self, metadata: NixlConnectorMetadata): + """ + Start loading by triggering non-blocking nixl_xfer. + We check for these trnxs to complete in each step(). + """ + for req_id, meta in metadata.requests.items(): + logger.debug( + "start_load_kv for request %s from remote engine %s. " + "Num local_block_ids: %s. Num remote_block_ids: %s. ", req_id, + meta.remote_engine_id, len(meta.local_block_ids), + len(meta.remote_block_ids)) + self._read_blocks( + request_id=req_id, + dst_engine_id=meta.remote_engine_id, + local_block_ids=meta.local_block_ids, + remote_block_ids=meta.remote_block_ids, + remote_host=meta.remote_host, + remote_port=meta.remote_port, + ) + + def _read_blocks( + self, + local_block_ids: list[int], + remote_block_ids: list[int], + remote_host: str, + remote_port: int, + dst_engine_id: str, + request_id: str, + ): + # NOTE(rob): this takes ~2s. We need to get this off the hotpath. + if dst_engine_id not in self._remote_agents: + self._nixl_handshake(remote_host, remote_port) + + # NOTE(rob): having the staging blocks be on the READER side is + # not going to work well (since we will have to call rearrange tensors). + # after we detect the txn is complete (which means we cannot make the + # read trxn async easily). If we want to make "READ" happen cleanly, + # then we will need to have the staging blocks on the remote side. + + # NOTE(rob): according to nvidia the staging blocks are used to + # saturate IB with heterogeneous TP sizes. We should remove the staging + # blocks until we are ready. + + # Full prefix cache hit: do not need to read remote blocks, + # just notify P worker that we have the blocks we need. + num_local_blocks = len(local_block_ids) + if num_local_blocks == 0: + self.nixl_wrapper.send_notif(dst_engine_id, + notif_msg=request_id.encode("utf-8")) + return + + # Partial prefix cache hit: just read uncomputed blocks. + num_remote_blocks = len(remote_block_ids) + assert num_local_blocks <= num_remote_blocks + if num_local_blocks < num_remote_blocks: + remote_block_ids = remote_block_ids[-num_local_blocks:] + + # Get side handles. + local_xfer_side_handle = self.src_xfer_side_handle + remote_xfer_side_handle = self.dst_xfer_side_handles[dst_engine_id] + + # Get descs ids. + remote_block_descs_ids = self._get_block_descs_ids( + dst_engine_id, remote_block_ids) + local_block_descs_ids = self._get_block_descs_ids( + self.engine_id, local_block_ids) + assert len(local_block_descs_ids) == len(remote_block_descs_ids) + + # Prepare transfer with Nixl. + handle = self.nixl_wrapper.make_prepped_xfer( + "READ", + local_xfer_side_handle, + local_block_descs_ids, + remote_xfer_side_handle, + remote_block_descs_ids, + notif_msg=request_id.encode("utf-8"), + ) + + # Begin async xfer. + self.nixl_wrapper.transfer(handle) + + # Use handle to check completion in future step(). + self._recving_transfers[request_id].append(handle) + + def _get_block_descs_ids(self, engine_id: str, + block_ids: list[int]) -> list[int]: + """Get the descs ids for a set of block ids.""" + + # range(1) for MLA, range(2) otherwise. + region_ids = range(self.num_regions) + num_blocks = self.dst_num_blocks[engine_id] + + # Compute the desc ids for each block. + descs_ids: list[int] = [] + for reg_id in region_ids: + for block_id in block_ids: + descs_ids.append(reg_id * num_blocks + block_id) + return descs_ids + + +@contextlib.contextmanager +def zmq_ctx(socket_type: Any, addr: str) -> Iterator[zmq.Socket]: + """Context manager for a ZMQ socket""" + + ctx: Optional[zmq.Context] = None + try: + ctx = zmq.Context() # type: ignore[attr-defined] + + if socket_type == zmq.ROUTER: + socket = ctx.socket(zmq.ROUTER) + socket.bind(addr) + elif socket_type == zmq.REQ: + socket = ctx.socket(zmq.REQ) + socket.connect(addr) + else: + raise ValueError(f"Unexpected socket type: {socket_type}") + + yield socket + finally: + if ctx is not None: + ctx.destroy(linger=0) diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/shared_storage_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/shared_storage_connector.py index f91ffbc720e7..0fedb6fd5ed9 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/shared_storage_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/shared_storage_connector.py @@ -17,6 +17,7 @@ if TYPE_CHECKING: from vllm.attention.backends.abstract import AttentionMetadata from vllm.forward_context import ForwardContext + from vllm.v1.core.kv_cache_manager import KVCacheBlocks from vllm.v1.request import Request logger = init_logger(__name__) @@ -132,8 +133,7 @@ def inject_kv_into_layer( dst_kv_cache_layer.reshape(dst_kv_cache_layer_shape) # Get the metadata - metadata: KVConnectorMetadata = \ - self._get_connector_metadata() + metadata: KVConnectorMetadata = self._get_connector_metadata() assert isinstance(metadata, SharedStorageConnectorMetadata) if metadata is None: @@ -225,7 +225,7 @@ def get_num_new_matched_tokens( self, request: "Request", num_computed_tokens: int, - ) -> int: + ) -> tuple[int, bool]: """ Get number of new tokens that can be loaded from the external KV cache beyond the num_computed_tokens. @@ -239,7 +239,6 @@ def get_num_new_matched_tokens( the number of tokens that can be loaded from the external KV cache beyond what is already computed. """ - # NOTE: in this debug implementation, we assume that the prompt is # cached_prompt + newly_generated_single_token # Therefore, we use prompt_token_ids[:-1] to determine the folder name @@ -248,7 +247,7 @@ def get_num_new_matched_tokens( # with the block granularity. And it expects the returned blocks and # num_computed_tokens to also be aligned with the block granularity. if not self._found_match_for_request(request): - return 0 + return 0, False logger.info("External Cache Hit!") @@ -257,9 +256,10 @@ def get_num_new_matched_tokens( num_tokens_to_check = align_to_block_size( len(request.prompt_token_ids) - 1, self._block_size) - return num_tokens_to_check - num_computed_tokens + return num_tokens_to_check - num_computed_tokens, False def update_state_after_alloc(self, request: "Request", + blocks: "KVCacheBlocks", num_external_tokens: int): """ Update KVConnector state after block allocation. diff --git a/vllm/entrypoints/openai/protocol.py b/vllm/entrypoints/openai/protocol.py index 4e09240f23af..09b0be5c21e2 100644 --- a/vllm/entrypoints/openai/protocol.py +++ b/vllm/entrypoints/openai/protocol.py @@ -402,6 +402,9 @@ class ChatCompletionRequest(OpenAIBaseModel): "access by 3rd parties, and long enough to be " "unpredictable (e.g., 43 characters base64-encoded, corresponding " "to 256 bit). Not supported by vLLM engine V0.")) + kv_transfer_params: Optional[dict[str, Any]] = Field( + default=None, + description="KVTransfer parameters used for disaggregated serving.") # doc: end-chat-completion-extra-params @@ -539,7 +542,9 @@ def to_sampling_params( output_kind=RequestOutputKind.DELTA if self.stream \ else RequestOutputKind.FINAL_ONLY, guided_decoding=guided_decoding, - logit_bias=self.logit_bias) + logit_bias=self.logit_bias, + extra_args=({"kv_transfer_params": self.kv_transfer_params} + if self.kv_transfer_params else None)) def _get_guided_json_from_tool( self) -> Optional[Union[str, dict, BaseModel]]: @@ -847,6 +852,10 @@ class CompletionRequest(OpenAIBaseModel): " as strings of the form 'token_id:{token_id}' so that tokens " "that are not JSON-encodable can be identified.")) + kv_transfer_params: Optional[dict[str, Any]] = Field( + default=None, + description="KVTransfer parameters used for disaggregated serving.") + # doc: end-completion-extra-params # Default sampling parameters for completion requests @@ -972,7 +981,9 @@ def to_sampling_params( else RequestOutputKind.FINAL_ONLY, guided_decoding=guided_decoding, logit_bias=self.logit_bias, - allowed_token_ids=self.allowed_token_ids) + allowed_token_ids=self.allowed_token_ids, + extra_args=({"kv_transfer_params": self.kv_transfer_params} + if self.kv_transfer_params else None)) @model_validator(mode="before") @classmethod @@ -1222,6 +1233,8 @@ class CompletionResponse(OpenAIBaseModel): model: str choices: list[CompletionResponseChoice] usage: UsageInfo + kv_transfer_params: Optional[dict[str, Any]] = Field( + default=None, description="KVTransfer parameters.") class CompletionResponseStreamChoice(OpenAIBaseModel): @@ -1411,6 +1424,8 @@ class ChatCompletionResponse(OpenAIBaseModel): choices: list[ChatCompletionResponseChoice] usage: UsageInfo prompt_logprobs: Optional[list[Optional[dict[int, Logprob]]]] = None + kv_transfer_params: Optional[dict[str, Any]] = Field( + default=None, description="KVTransfer parameters.") class DeltaMessage(OpenAIBaseModel): diff --git a/vllm/entrypoints/openai/serving_chat.py b/vllm/entrypoints/openai/serving_chat.py index 5c11836fbff4..368a7e0864e9 100644 --- a/vllm/entrypoints/openai/serving_chat.py +++ b/vllm/entrypoints/openai/serving_chat.py @@ -1075,6 +1075,7 @@ async def chat_completion_full_generator( choices=choices, usage=usage, prompt_logprobs=clamp_prompt_logprobs(final_res.prompt_logprobs), + kv_transfer_params=final_res.kv_transfer_params, ) return response diff --git a/vllm/entrypoints/openai/serving_completion.py b/vllm/entrypoints/openai/serving_completion.py index 1067f35ce240..0b3bdf7d4821 100644 --- a/vllm/entrypoints/openai/serving_completion.py +++ b/vllm/entrypoints/openai/serving_completion.py @@ -482,7 +482,7 @@ def request_output_to_completion_response( model=model_name, choices=choices, usage=usage, - ) + kv_transfer_params=final_res_batch[0].kv_transfer_params) def _create_completion_logprobs( self, diff --git a/vllm/envs.py b/vllm/envs.py index d7f332cb0a73..b3faad03d345 100644 --- a/vllm/envs.py +++ b/vllm/envs.py @@ -112,6 +112,8 @@ VLLM_XGRAMMAR_CACHE_MB: int = 0 VLLM_MSGPACK_ZERO_COPY_THRESHOLD: int = 256 VLLM_ALLOW_INSECURE_SERIALIZATION: bool = False + VLLM_NIXL_SIDE_CHANNEL_HOST: str = "localhost" + VLLM_NIXL_SIDE_CHANNEL_PORT: int = 5557 def get_default_cache_root(): @@ -747,6 +749,14 @@ def maybe_convert_int(value: Optional[str]) -> Optional[int]: # insecure method and it is needed for some reason. "VLLM_ALLOW_INSECURE_SERIALIZATION": lambda: bool(int(os.getenv("VLLM_ALLOW_INSECURE_SERIALIZATION", "0"))), + + # IP address used for NIXL handshake between remote agents. + "VLLM_NIXL_SIDE_CHANNEL_HOST": + lambda: os.getenv("VLLM_NIXL_SIDE_CHANNEL_HOST", "localhost"), + + # Port used for NIXL handshake between remote agents. + "VLLM_NIXL_SIDE_CHANNEL_PORT": + lambda: int(os.getenv("VLLM_NIXL_SIDE_CHANNEL_PORT", "5557")), } # end-env-vars-definition diff --git a/vllm/forward_context.py b/vllm/forward_context.py index 9ddc3d1f2c51..eb1e1f5694bb 100644 --- a/vllm/forward_context.py +++ b/vllm/forward_context.py @@ -11,10 +11,6 @@ import vllm.envs as envs from vllm.config import VllmConfig -from vllm.distributed.kv_transfer import (get_kv_transfer_group, - has_kv_transfer_group, - is_v1_kv_transfer_group) -from vllm.distributed.kv_transfer.kv_connector.v1 import KVConnectorBase_V1 from vllm.logger import init_logger if TYPE_CHECKING: @@ -106,16 +102,6 @@ def set_forward_context(attn_metadata: Any, attn_metadata=attn_metadata, dp_metadata=dp_metadata) - # KVConnector: trigger (possibly async) load before forward. - # Each attn layer will block until the reading is complete. - trigger_kv_transfer = (attn_metadata is not None - and has_kv_transfer_group() - and is_v1_kv_transfer_group()) - if trigger_kv_transfer: - kv_connector = get_kv_transfer_group() - assert isinstance(kv_connector, KVConnectorBase_V1) - kv_connector.start_load_kv(_forward_context) - try: yield finally: @@ -152,11 +138,4 @@ def set_forward_context(attn_metadata: Any, "(batchsize, count, median_time(ms)): %s"), forward_stats) - # KVConnector: each attn layer triggers (possibly async) save. - # Ensure all those operations complete before forward() is done. - if trigger_kv_transfer: - kv_connector = get_kv_transfer_group() - assert isinstance(kv_connector, KVConnectorBase_V1) - kv_connector.wait_for_save() - _forward_context = prev_context diff --git a/vllm/outputs.py b/vllm/outputs.py index 65a6ed01451d..6cd60575b00d 100644 --- a/vllm/outputs.py +++ b/vllm/outputs.py @@ -4,7 +4,7 @@ from collections.abc import MutableSequence from collections.abc import Sequence as GenericSequence from dataclasses import dataclass -from typing import Generic, Optional, Union +from typing import Any, Generic, Optional, Union import torch from typing_extensions import TypeVar, deprecated @@ -103,6 +103,7 @@ class RequestOutput: encoder_prompt_token_ids: The token IDs of the encoder prompt. None if decoder-only. num_cached_tokens: The number of tokens with prefix cache hit. + kv_transfer_params: The params for remote K/V transfer. """ def __init__( @@ -120,6 +121,7 @@ def __init__( num_cached_tokens: Optional[int] = None, *, multi_modal_placeholders: Optional[MultiModalPlaceholderDict] = None, + kv_transfer_params: Optional[dict[str, Any]] = None, ) -> None: self.request_id = request_id self.prompt = prompt @@ -133,11 +135,13 @@ def __init__( self.encoder_prompt = encoder_prompt self.encoder_prompt_token_ids = encoder_prompt_token_ids self.num_cached_tokens = num_cached_tokens + self.kv_transfer_params = kv_transfer_params def add(self, next_output: "RequestOutput", aggregate: bool) -> None: """Merge subsequent RequestOutput into this one""" self.finished |= next_output.finished + self.kv_transfer_params = next_output.kv_transfer_params for next_completion in next_output.outputs: for i, completion in enumerate(self.outputs): diff --git a/vllm/v1/core/kv_cache_manager.py b/vllm/v1/core/kv_cache_manager.py index ad8468a89dc5..27368374ea8d 100644 --- a/vllm/v1/core/kv_cache_manager.py +++ b/vllm/v1/core/kv_cache_manager.py @@ -36,6 +36,12 @@ def get_block_ids(self) -> list[int]: """Converts the KVCacheBlocks instance to a list of block IDs.""" return [block.block_id for block in self.blocks] + def get_unhashed_block_ids(self) -> list[int]: + """Get block_ids of unhashed blocks from KVCacheBlocks instance.""" + return [ + block.block_id for block in self.blocks if block.block_hash is None + ] + class KVCacheManager: @@ -116,6 +122,12 @@ def get_computed_blocks(self, - The number of computed tokens. """ + # Request already has blocks from async load via KVConnector. + num_existing_blocks = len( + self.single_type_manager.req_to_blocks[request.request_id]) + if num_existing_blocks > 0: + return KVCacheBlocks.create_empty(), request.num_computed_tokens + # Prefix caching is disabled or # When the request requires prompt logprobs, we skip prefix caching. if (not self.enable_caching @@ -173,6 +185,7 @@ def allocate_slots( num_new_tokens: int, new_computed_blocks: Optional[KVCacheBlocks] = None, num_lookahead_tokens: int = 0, + delay_cache_blocks: bool = False, ) -> Optional[KVCacheBlocks]: """Add slots for a request with new tokens to append. @@ -186,6 +199,9 @@ def allocate_slots( num_lookahead_tokens: The number of speculative tokens to allocate. This is used by spec decode proposers with kv-cache such as eagle. + delay_cache_blocks: Whether to skip caching the blocks. This is + used by P/D when allocating blocks used in a KV transfer + which will complete in a future step. Blocks layout: ``` @@ -255,7 +271,9 @@ def allocate_slots( new_blocks = self.single_type_manager.allocate_new_blocks( request.request_id, num_tokens_need_slot) - if not self.enable_caching: + # P/D: delay caching blocks if we have to recv from + # remote. Update state for locally cached blocks. + if not self.enable_caching or delay_cache_blocks: return KVCacheBlocks(new_blocks) # Speculated tokens might be rejected in the future, so we does @@ -350,3 +368,16 @@ def take_events(self) -> list[KVCacheEvent]: A list of KV cache events. """ return self.block_pool.take_events() + + def get_block_ids(self, request_id: str) -> list[int]: + """Get the block ids of a request.""" + assert request_id in self.single_type_manager.req_to_blocks + return [ + block.block_id + for block in self.single_type_manager.req_to_blocks[request_id] + ] + + def get_num_blocks(self, request_id: str): + """Get the number of blocks.""" + assert request_id in self.single_type_manager.req_to_blocks + return len(self.single_type_manager.req_to_blocks[request_id]) diff --git a/vllm/v1/core/sched/interface.py b/vllm/v1/core/sched/interface.py index 0b328f510903..c17f80b6ae78 100644 --- a/vllm/v1/core/sched/interface.py +++ b/vllm/v1/core/sched/interface.py @@ -4,6 +4,7 @@ from typing import TYPE_CHECKING, Optional, Union if TYPE_CHECKING: + from vllm.distributed.kv_transfer.kv_connector.v1 import KVConnectorBase_V1 from vllm.v1.core.sched.output import SchedulerOutput from vllm.v1.engine import EngineCoreOutputs from vllm.v1.metrics.stats import SchedulerStats @@ -137,3 +138,6 @@ def make_stats(self) -> Optional["SchedulerStats"]: def shutdown(self) -> None: """Shutdown the scheduler.""" raise NotImplementedError + + def get_kv_connector(self) -> Optional["KVConnectorBase_V1"]: + return None diff --git a/vllm/v1/core/sched/scheduler.py b/vllm/v1/core/sched/scheduler.py index 258e0d570e3e..7773853b096a 100644 --- a/vllm/v1/core/sched/scheduler.py +++ b/vllm/v1/core/sched/scheduler.py @@ -5,13 +5,15 @@ import time from collections import defaultdict, deque from collections.abc import Iterable -from typing import Optional, Union +from typing import Any, Optional, Union from vllm.config import VllmConfig from vllm.distributed.kv_events import EventPublisherFactory, KVEventBatch from vllm.distributed.kv_transfer.kv_connector.factory import ( KVConnectorFactory) -from vllm.distributed.kv_transfer.kv_connector.v1 import KVConnectorRole +from vllm.distributed.kv_transfer.kv_connector.v1 import (KVConnectorBase_V1, + KVConnectorRole, + KVTransferParams) from vllm.logger import init_logger from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry from vllm.v1.core.encoder_cache_manager import (EncoderCacheManager, @@ -96,6 +98,9 @@ def __init__( # This is flushed at the end of each scheduling step. self.finished_req_ids: set[str] = set() + # P/D: requests in process of recving KV transfers + self.finished_recving_kv_req_ids: set[str] = set() + # OPTIMIZATION: Cache the CachedRequestData objects to avoid creating # them at each scheduling step. # Request id -> deque of CachedRequestData @@ -307,6 +312,16 @@ def schedule(self) -> SchedulerOutput: request = self.waiting[0] + # P/D: skip request if still waiting for remote kvs. + if request.status == RequestStatus.WAITING_FOR_REMOTE_KVS: + is_ready = self._update_waiting_for_remote_kv(request) + if is_ready: + request.status = RequestStatus.WAITING + else: + self.waiting.popleft() + skipped_waiting_requests.appendleft(request) + continue + # Skip request if the structured output request is still waiting # for FSM compilation. if request.status == RequestStatus.WAITING_FOR_FSM: @@ -330,49 +345,55 @@ def schedule(self) -> SchedulerOutput: continue # Get already-cached tokens. - computed_blocks, num_computed_tokens = \ + new_computed_blocks, num_computed_tokens = \ self.kv_cache_manager.get_computed_blocks( request) # Get externally-cached tokens if using a KVConnector. - num_external_tokens = ( - 0 if self.connector is None else + num_external_tokens, load_kv_async = ( + (0, False) if self.connector is None else self.connector.get_num_new_matched_tokens( request, num_computed_tokens)) # Total computed tokens (local + external). num_computed_tokens += num_external_tokens + encoder_inputs_to_schedule = None + new_encoder_budget = encoder_budget + + # P/D: loading remote KV, do not allocate for new work. + if load_kv_async: + num_new_tokens = 0 # Number of tokens to be scheduled. - # We use `request.num_tokens` instead of - # `request.num_prompt_tokens` to consider the resumed requests, - # which have output tokens. - num_new_tokens = request.num_tokens - num_computed_tokens - if (0 < self.scheduler_config.long_prefill_token_threshold < - num_new_tokens): - num_new_tokens = ( - self.scheduler_config.long_prefill_token_threshold) - num_new_tokens = min(num_new_tokens, token_budget) - assert num_new_tokens > 0 - - # Schedule encoder inputs. - if request.has_encoder_inputs: - (encoder_inputs_to_schedule, num_new_tokens, - new_encoder_budget) = self._try_schedule_encoder_inputs( - request, num_computed_tokens, num_new_tokens, - encoder_budget) - if num_new_tokens == 0: - # The request cannot be scheduled. - break else: - encoder_inputs_to_schedule = None - new_encoder_budget = encoder_budget + # We use `request.num_tokens` instead of + # `request.num_prompt_tokens` to consider the resumed + # requests, which have output tokens. + num_new_tokens = request.num_tokens - num_computed_tokens + if (0 < self.scheduler_config.long_prefill_token_threshold + < num_new_tokens): + num_new_tokens = ( + self.scheduler_config.long_prefill_token_threshold) + num_new_tokens = min(num_new_tokens, token_budget) + assert num_new_tokens > 0 + + # Schedule encoder inputs. + if request.has_encoder_inputs: + (encoder_inputs_to_schedule, num_new_tokens, + new_encoder_budget + ) = self._try_schedule_encoder_inputs( + request, num_computed_tokens, num_new_tokens, + encoder_budget) + if num_new_tokens == 0: + # The request cannot be scheduled. + break new_blocks = self.kv_cache_manager.allocate_slots( request, num_new_tokens + num_external_tokens, - computed_blocks, + new_computed_blocks, num_lookahead_tokens=self.num_lookahead_tokens, + delay_cache_blocks=load_kv_async, ) if new_blocks is None: # The request cannot be scheduled. @@ -384,10 +405,18 @@ def schedule(self) -> SchedulerOutput: if self.connector is not None: self.connector.update_state_after_alloc( request, + new_computed_blocks + new_blocks, num_external_tokens, ) self.waiting.popleft() + if load_kv_async: + # If loading async, allocate memory and put request + # into the WAITING_FOR_REMOTE_KV state. + skipped_waiting_requests.appendleft(request) + request.status = RequestStatus.WAITING_FOR_REMOTE_KVS + continue + if request.use_structured_output: structured_output_request_ids[ request.request_id] = req_index @@ -407,7 +436,7 @@ def schedule(self) -> SchedulerOutput: if self.lora_config and request.lora_request: scheduled_loras.add(request.lora_request.lora_int_id) req_to_new_block_ids[request.request_id] = ( - computed_blocks + new_blocks).get_block_ids() + self.kv_cache_manager.get_block_ids(request.request_id)) num_scheduled_tokens[request.request_id] = num_new_tokens token_budget -= num_new_tokens request.status = RequestStatus.RUNNING @@ -698,6 +727,7 @@ def update_from_output( stopped = False new_logprobs = None new_token_ids = generated_token_ids + kv_transfer_params = None # Append generated tokens and check for stop. Note that if # a request is still being prefilled, we expect the model runner @@ -709,7 +739,7 @@ def update_from_output( # This must be called before we make the EngineCoreOutput. stopped = check_stop(request, self.max_model_len) if stopped: - self._free_request(request) + kv_transfer_params = self._free_request(request) del new_token_ids[num_new:] # Trim new tokens if needed. break @@ -739,7 +769,8 @@ def update_from_output( # Get prompt logprobs for this request. prompt_logprobs_tensors = prompt_logprobs_dict.get(req_id) - if new_token_ids: + if new_token_ids or kv_transfer_params: + # Add EngineCoreOutput for this Request. outputs.append( EngineCoreOutput( @@ -749,7 +780,10 @@ def update_from_output( new_logprobs=new_logprobs, new_prompt_logprobs_tensors=prompt_logprobs_tensors, stop_reason=request.stop_reason, - events=request.take_events())) + events=request.take_events(), + kv_transfer_params=kv_transfer_params, + )) + else: # Invariant: EngineCore returns no partial prefill outputs. assert not prompt_logprobs_tensors @@ -757,6 +791,9 @@ def update_from_output( if not stopped: new_running.append(request) + # P/D: update state for finished KV Transfers. + self._update_from_kv_xfer_finished(model_runner_output) + # Return the cached request data to the queue so they can be reused. for req_data in scheduler_output.scheduled_cached_reqs: # NOTE(rob): since we free stopped reqs above, adding stopped reqs @@ -811,15 +848,27 @@ def finish_requests( request.status = finished_status self._free_request(request) - def _free_request(self, request: Request) -> None: + def _free_request(self, request: Request) -> Optional[dict[str, Any]]: + assert request.is_finished() - self.kv_cache_manager.free(request) - self.kv_cache_manager.free_block_hashes(request) + + delay_free_blocks, kv_xfer_params = self._connector_finished(request) self.encoder_cache_manager.free(request) self._cached_reqs_data.pop(request.request_id, None) - del self.requests[request.request_id] self.finished_req_ids.add(request.request_id) + if not delay_free_blocks: + self._free_blocks(request) + + return kv_xfer_params + + def _free_blocks(self, request: Request): + assert request.is_finished() + assert request.request_id not in self._cached_reqs_data + self.kv_cache_manager.free(request) + self.kv_cache_manager.free_block_hashes(request) + del self.requests[request.request_id] + def get_num_unfinished_requests(self) -> int: return len(self.waiting) + len(self.running) @@ -863,3 +912,70 @@ def make_spec_decoding_stats( def shutdown(self) -> None: if self.kv_event_publisher: self.kv_event_publisher.shutdown() + + ######################################################################## + # P/D Related Methods + ######################################################################## + + def get_kv_connector(self) -> Optional[KVConnectorBase_V1]: + return self.connector + + def _connector_finished( + self, request: Request) -> tuple[bool, Optional[KVTransferParams]]: + """Invoke the KV connector request_finished() method if applicable.""" + if self.connector is None: + return False, None + block_ids = self.kv_cache_manager.get_block_ids(request.request_id) + return self.connector.request_finished(request, block_ids) + + def _update_waiting_for_remote_kv(self, request: Request) -> bool: + """ + P/D: check if the request_id is finished_recving. + + The finished_recving_kv_req_ids list is populated + on the previous steps()'s update_from_output based + on the worker side connector. + + When the kv transfer is ready, we cache the blocks + and the request state will be moved back to WAITING from + WAITING_FOR_REMOTE_KV. + """ + if request.request_id not in self.finished_recving_kv_req_ids: + return False + + # Now that the blocks are ready, actually cache them. + block_ids = self.kv_cache_manager.get_block_ids(request.request_id) + num_computed_tokens = len(block_ids) * self.block_size + if num_computed_tokens == request.num_tokens: + num_computed_tokens -= 1 + self.kv_cache_manager.single_type_manager.cache_blocks( + request, + self.kv_cache_manager.req_to_block_hashes[request.request_id], + num_computed_tokens, + ) + + # Update the request state for scheduling. + request.num_computed_tokens = num_computed_tokens + + # Return that we are ready. + self.finished_recving_kv_req_ids.remove(request.request_id) + return True + + def _update_from_kv_xfer_finished(self, + model_runner_output: ModelRunnerOutput): + """ + P/D: update the scheduler state based on the output. + + The Worker side connectors add finished_recving and + finished_sending reqs to the output. + * if finished_sending: free the blocks + # if finished_recving: add to state so we can + scheduler the request during the next step. + """ + # P/D: update recv and send status from last step. + for req_id in (model_runner_output.finished_recving or ()): + logger.debug("Finished recving KV transfer for request %s", req_id) + self.finished_recving_kv_req_ids.add(req_id) + for req_id in (model_runner_output.finished_sending or ()): + logger.debug("Finished sending KV transfer for request %s", req_id) + self._free_blocks(self.requests[req_id]) diff --git a/vllm/v1/engine/__init__.py b/vllm/v1/engine/__init__.py index e33d1a1e5dcd..122a5a72cc36 100644 --- a/vllm/v1/engine/__init__.py +++ b/vllm/v1/engine/__init__.py @@ -105,6 +105,7 @@ class EngineCoreOutput( finish_reason: Optional[FinishReason] = None stop_reason: Union[int, str, None] = None events: Optional[list[EngineCoreEvent]] = None + kv_transfer_params: Optional[dict[str, Any]] = None @property def finished(self) -> bool: diff --git a/vllm/v1/engine/core.py b/vllm/v1/engine/core.py index d9dd4957cff2..c1aa0ce27d3f 100644 --- a/vllm/v1/engine/core.py +++ b/vllm/v1/engine/core.py @@ -182,6 +182,15 @@ def add_request(self, request: EngineCoreRequest): # Start grammar compilation asynchronously self.structured_output_manager.grammar_init(req) + if req.raw_kv_transfer_params is not None: + if (kv_connector := self.scheduler.get_kv_connector()): + # Parse raw KV transfer params via connector. + kv_connector.set_kv_transfer_params(req) + else: + logger.warning( + "Got KVTransferParams, but no KVConnector found. " + "Disabling KVTransfer for this request.") + self.scheduler.add_request(req) def abort_requests(self, request_ids: list[str]): diff --git a/vllm/v1/engine/output_processor.py b/vllm/v1/engine/output_processor.py index 5f5ffe6e09db..a7a9b0e4a161 100644 --- a/vllm/v1/engine/output_processor.py +++ b/vllm/v1/engine/output_processor.py @@ -3,7 +3,7 @@ import asyncio from collections.abc import Iterable from dataclasses import dataclass -from typing import Optional, Union +from typing import Any, Optional, Union from vllm.outputs import CompletionOutput, RequestOutput from vllm.sampling_params import RequestOutputKind @@ -146,6 +146,7 @@ def make_request_output( new_token_ids: list[int], finish_reason: Optional[FinishReason], stop_reason: Union[int, str, None], + kv_transfer_params: Optional[dict[str, Any]] = None, ) -> Optional[RequestOutput]: finished = finish_reason is not None @@ -167,13 +168,15 @@ def make_request_output( if not outputs: return None - return self._new_request_output(request_id, outputs, finished) + return self._new_request_output(request_id, outputs, finished, + kv_transfer_params) def _new_request_output( self, request_id: str, outputs: list[CompletionOutput], finished: bool, + kv_transfer_params: Optional[dict[str, Any]] = None, ) -> RequestOutput: if self.output_kind == RequestOutputKind.DELTA: @@ -189,6 +192,7 @@ def _new_request_output( prompt_logprobs=prompt_logprobs, outputs=outputs, finished=finished, + kv_transfer_params=kv_transfer_params, ) def _new_completion_output( @@ -335,6 +339,7 @@ def process_outputs( new_token_ids = engine_core_output.new_token_ids finish_reason = engine_core_output.finish_reason stop_reason = engine_core_output.stop_reason + kv_transfer_params = engine_core_output.kv_transfer_params req_state.is_prefilling = False @@ -350,7 +355,8 @@ def process_outputs( # 4) Create and handle RequestOutput objects. if request_output := req_state.make_request_output( - new_token_ids, finish_reason, stop_reason): + new_token_ids, finish_reason, stop_reason, + kv_transfer_params): if req_state.queue is not None: # AsyncLLM: put into queue for handling by generate(). req_state.queue.put(request_output) diff --git a/vllm/v1/outputs.py b/vllm/v1/outputs.py index 2732b933c28a..e8ce0df5ed8d 100644 --- a/vllm/v1/outputs.py +++ b/vllm/v1/outputs.py @@ -100,12 +100,16 @@ class ModelRunnerOutput: # [prompt_len] prompt_logprobs_dict: dict[str, Optional[LogprobsTensors]] - -EMPTY_MODEL_RUNNER_OUTPUT = ModelRunnerOutput( - req_ids=[], - req_id_to_index={}, - sampled_token_ids=[], - spec_token_ids=None, - logprobs=None, - prompt_logprobs_dict={}, -) + # [req_ids] + finished_sending: Optional[set[str]] = None + finished_recving: Optional[set[str]] = None + + +EMPTY_MODEL_RUNNER_OUTPUT = ModelRunnerOutput(req_ids=[], + req_id_to_index={}, + sampled_token_ids=[], + spec_token_ids=None, + logprobs=None, + prompt_logprobs_dict={}, + finished_sending=None, + finished_recving=None) diff --git a/vllm/v1/request.py b/vllm/v1/request.py index fde366d61c7d..fc6b738546f4 100644 --- a/vllm/v1/request.py +++ b/vllm/v1/request.py @@ -1,8 +1,9 @@ # SPDX-License-Identifier: Apache-2.0 import enum -from typing import TYPE_CHECKING, Optional, Union +from typing import TYPE_CHECKING, Any, Optional, Union +from vllm.distributed.kv_transfer.kv_connector.v1 import KVTransferParams from vllm.multimodal.inputs import MultiModalKwargs, PlaceholderRange from vllm.sampling_params import SamplingParams from vllm.utils import is_list_of @@ -61,6 +62,15 @@ def __init__( self.num_encoder_inputs = len(self.mm_inputs) self.has_encoder_inputs = self.num_encoder_inputs > 0 + # P/D: KV transfer parameters (raw and parsed). + raw_params = (None if sampling_params.extra_args is None + else sampling_params.extra_args.get( + "kv_transfer_params", None)) + self.raw_kv_transfer_params: Optional[dict[str, Any]] = raw_params + # Each connector parses the raw dictionary and sets this + # attr the first time that the request is processed. + self.kv_transfer_params: Optional[KVTransferParams] = None + # Sanity check assert len(self.mm_inputs) == len(self.mm_positions) if self.mm_hashes: @@ -150,6 +160,7 @@ class RequestStatus(enum.IntEnum): """Status of a request.""" WAITING = enum.auto() WAITING_FOR_FSM = enum.auto() + WAITING_FOR_REMOTE_KVS = enum.auto() RUNNING = enum.auto() PREEMPTED = enum.auto() # Note: anything after PREEMPTED will be considered diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index fdb1339cddca..bd833735b695 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -1,5 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 +import copy import gc import time import weakref @@ -17,8 +18,9 @@ get_layers_from_vllm_config) from vllm.distributed.kv_transfer import (get_kv_transfer_group, has_kv_transfer_group) +from vllm.distributed.kv_transfer.kv_connector.v1 import KVConnectorBase_V1 from vllm.distributed.parallel_state import get_pp_group, graph_capture -from vllm.forward_context import set_forward_context +from vllm.forward_context import get_forward_context, set_forward_context from vllm.logger import init_logger from vllm.model_executor.layers.rotary_embedding import MRotaryEmbedding from vllm.model_executor.model_loader import get_model @@ -1065,15 +1067,14 @@ def execute_model( scheduler_output: "SchedulerOutput", intermediate_tensors: Optional[IntermediateTensors] = None, ) -> Union[ModelRunnerOutput, IntermediateTensors]: - # Update KVConnector with the KVConnector metadata forward(). - if has_kv_transfer_group(): - get_kv_transfer_group().bind_connector_metadata( - scheduler_output.kv_connector_metadata) self._update_states(scheduler_output) if not scheduler_output.total_num_scheduled_tokens: - # Return empty ModelRunnerOutput if there's no work to do. - return EMPTY_MODEL_RUNNER_OUTPUT + if not has_kv_transfer_group(): + # Return empty ModelRunnerOutput if there's no work to do. + return EMPTY_MODEL_RUNNER_OUTPUT + + return self.kv_connector_no_forward(scheduler_output) # Prepare the decoder inputs. attn_metadata, logits_indices, spec_decode_metadata = ( @@ -1150,17 +1151,23 @@ def execute_model( with set_forward_context(attn_metadata, self.vllm_config, num_tokens=num_input_tokens): - output = self.model( + self.maybe_setup_kv_connector(scheduler_output) + + model_output = self.model( input_ids=input_ids, positions=positions, intermediate_tensors=intermediate_tensors, inputs_embeds=inputs_embeds, ) + self.maybe_wait_for_kv_save() + finished_sending, finished_recving = ( + self.get_finished_kv_transfers(scheduler_output)) + if self.use_aux_hidden_state_outputs: - hidden_states, aux_hidden_states = output + hidden_states, aux_hidden_states = model_output else: - hidden_states = output + hidden_states = model_output if not get_pp_group().is_last_rank: # For mid-pipeline stages, return the hidden states. @@ -1341,8 +1348,56 @@ def execute_model( spec_token_ids=spec_token_ids, logprobs=logprobs_lists, prompt_logprobs_dict=prompt_logprobs_dict, + finished_sending=finished_sending, + finished_recving=finished_recving, ) + def kv_connector_no_forward( + self, scheduler_output: "SchedulerOutput") -> ModelRunnerOutput: + # KV send/recv even if no work to do. + with set_forward_context(None, self.vllm_config): + self.maybe_setup_kv_connector(scheduler_output) + finished_sending, finished_recving = ( + self.get_finished_kv_transfers(scheduler_output)) + + if not finished_sending and not finished_recving: + return EMPTY_MODEL_RUNNER_OUTPUT + + output = copy.copy(EMPTY_MODEL_RUNNER_OUTPUT) + output.finished_sending = finished_sending + output.finished_recving = finished_recving + return output + + @staticmethod + def maybe_setup_kv_connector(scheduler_output: "SchedulerOutput"): + # Update KVConnector with the KVConnector metadata forward(). + if has_kv_transfer_group(): + kv_connector = get_kv_transfer_group() + assert isinstance(kv_connector, KVConnectorBase_V1) + assert scheduler_output.kv_connector_metadata is not None + kv_connector.bind_connector_metadata( + scheduler_output.kv_connector_metadata) + + # Background KV cache transfers happen here. + # These transfers are designed to be async and the requests + # involved may be disjoint from the running requests. + # Do this here to save a collective_rpc. + kv_connector.start_load_kv(get_forward_context()) + + @staticmethod + def maybe_wait_for_kv_save() -> None: + if has_kv_transfer_group(): + get_kv_transfer_group().wait_for_save() + + @staticmethod + def get_finished_kv_transfers( + scheduler_output: "SchedulerOutput", + ) -> tuple[Optional[set[str]], Optional[set[str]]]: + if has_kv_transfer_group(): + return get_kv_transfer_group().get_finished( + scheduler_output.finished_req_ids) + return None, None + def generate_draft_token_ids( self, sampled_token_ids: list[list[int]], @@ -1813,6 +1868,9 @@ def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None: self.vllm_config.compilation_config.static_forward_context, self.kv_caches) + if has_kv_transfer_group(): + get_kv_transfer_group().register_kv_caches(kv_caches) + self.attn_metadata_builder = self.attn_backend.get_builder_cls()( weakref.proxy(self), kv_cache_config.kv_cache_groups[0].kv_cache_spec,