2 changes: 1 addition & 1 deletion .github/workflows/cleanup_pr_body.yml
@@ -13,7 +13,7 @@ jobs:

steps:
- name: Checkout repository
uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5.0.0

- name: Set up Python
uses: actions/setup-python@42375524e23c412d93fb67b49958b491fce71c38 # v5.4.0
2 changes: 1 addition & 1 deletion .github/workflows/lint-and-deploy.yaml
@@ -14,7 +14,7 @@ jobs:
runs-on: ubuntu-latest
steps:
- name: Checkout
uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5.0.0
with:
fetch-depth: 0

2 changes: 1 addition & 1 deletion .github/workflows/pre-commit.yml
@@ -16,7 +16,7 @@ jobs:
pre-commit:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
- uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5.0.0
- uses: actions/setup-python@42375524e23c412d93fb67b49958b491fce71c38 # v5.4.0
with:
python-version: "3.12"
4 changes: 2 additions & 2 deletions .github/workflows/publish.yml
@@ -21,7 +21,7 @@ jobs:
upload_url: ${{ steps.create_release.outputs.upload_url }}
steps:
- name: Checkout
uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5.0.0

- name: Extract branch info
shell: bash
@@ -55,7 +55,7 @@ jobs:

# steps:
# - name: Checkout
# uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
# uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5.0.0

# - name: Setup ccache
# uses: hendrikmuhs/ccache-action@ed74d11c0b343532753ecead8a951bb09bb34bc9 # v1.2.14
35 changes: 35 additions & 0 deletions vllm/engine/llm_engine.py
@@ -417,6 +417,14 @@
logger.info(("init engine (profile, create kv cache, "
"warmup model) took %.2f seconds"), elapsed)

self.max_kv_cache_size = num_gpu_blocks * self.cache_config.block_size


def get_max_kv_cache_size(self) -> int:
"""Get the maximum size of the KV cache."""
return self.max_kv_cache_size


@classmethod
def _get_executor_cls(cls,
engine_config: VllmConfig) -> Type[ExecutorBase]:
@@ -559,6 +567,7 @@
return None

self._validate_model_inputs(processed_inputs, lora_request)
self._compute_free_tokens(processed_inputs, params)
# Create the sequences.
block_size = self.cache_config.block_size
seq_id = next(self.seq_counter)
@@ -1842,8 +1851,34 @@

# TODO: Find out how many placeholder tokens are there so we can
# check that chunked prefill does not truncate them
# max_batch_len = self.scheduler_config.max_num_batched_tokens

Check failure (GitHub Actions / pre-commit): Ruff F821 at vllm/engine/llm_engine.py:1854-1855: undefined names `LLMInputs` and `EncoderDecoderLLMInputs`.
def _compute_free_tokens(
self, inputs: Union[LLMInputs, EncoderDecoderLLMInputs], params
):
if self.model_config.is_multimodal_model:
# For encoder-decoder multimodal models, it is the decoder
# prompt that counts against the context length
prompt_ids = inputs.get("prompt_token_ids")
elif self.is_encoder_decoder_model():
prompt_ids = inputs.get("encoder_prompt_token_ids")
else:
prompt_ids = inputs.get("prompt_token_ids")

input_token_count = len(prompt_ids)
max_token_count = self.model_config.max_model_len

max_output_tokens = 0
if isinstance(params, SamplingParams):
max_output_tokens = params.max_tokens

free_tokens = max_token_count - max_output_tokens - input_token_count

logger.info("Free tokens available after tokenization: %d",
free_tokens)

if free_tokens > 0:
self.send_free_tokens_callback(free_tokens)

def _build_logits_processors(
self, sampling_params: SamplingParams,
lora_request: Optional[LoRARequest]) -> SamplingParams:
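For review context, a minimal standalone sketch of the accounting this file now does, using only the quantities visible in the hunks above (num_gpu_blocks, block_size, max_model_len, SamplingParams.max_tokens); the concrete numbers below are made up for illustration.

def max_kv_cache_size(num_gpu_blocks: int, block_size: int) -> int:
    # Mirrors the new attribute set in _initialize_kv_caches: total
    # token slots available in the KV cache.
    return num_gpu_blocks * block_size


def free_tokens_for_request(max_model_len: int, prompt_len: int,
                            max_output_tokens: int) -> int:
    # Mirrors _compute_free_tokens: context-window slots left after
    # reserving the prompt and the worst-case sampled output.
    return max_model_len - max_output_tokens - prompt_len


# Illustrative numbers only:
print(max_kv_cache_size(num_gpu_blocks=4096, block_size=16))   # 65536
print(free_tokens_for_request(8192, 1000, 512))                # 6680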
6 changes: 6 additions & 0 deletions vllm/engine/multiprocessing/__init__.py
@@ -19,6 +19,7 @@
IPC_OUTPUT_EXT = "_output_socket"
IPC_HEALTH_EXT = "_health_socket"
IPC_DATA_EXT = "_data_socket"
IPC_TOKENS_EXT = "_tokens_socket"


class MQEngineDeadError(RuntimeError):
@@ -72,6 +73,7 @@ class RPCStartupRequest(Enum):
@dataclass
class RPCStartupResponse:
tracing_enabled: bool
max_kv_cache_size: int


class RPCUProfileRequest(Enum):
@@ -116,6 +118,10 @@ class RPCLoadAdapterRequest:
# Set the default value of request_id to a new UUID
request_id: str = field(default_factory=lambda: str(uuid.uuid4()))

@dataclass
class FreeTokensRequest:
free_token_count: int


@dataclass
class RPCAdapterLoadedResponse:
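As a quick illustration of the new message type, the snippet below round-trips a FreeTokensRequest through pickle, which is how the engine serializes it for the tokens socket and how the client decodes the received frame; the ZeroMQ transport itself is omitted here.

import pickle
from dataclasses import dataclass


@dataclass
class FreeTokensRequest:
    free_token_count: int


# Engine side: serialize the notification before pushing it to the socket.
payload = pickle.dumps(FreeTokensRequest(free_token_count=128))

# Client side: decode the received bytes back into the dataclass.
message = pickle.loads(payload)
assert message.free_token_count == 128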
61 changes: 61 additions & 0 deletions vllm/engine/multiprocessing/client.py
@@ -23,6 +23,7 @@
from vllm.engine.multiprocessing import (ENGINE_DEAD_ERROR, IPC_DATA_EXT,
IPC_HEALTH_EXT, IPC_INPUT_EXT,
IPC_OUTPUT_EXT, RPC_REQUEST_T,
IPC_TOKENS_EXT,
VLLM_RPC_SUCCESS_STR, RPCAbortRequest,
RPCAdapterLoadedResponse, RPCError,
RPCIsSleepingRequest,
@@ -120,6 +121,10 @@ def __init__(self, ipc_path: str, engine_config: VllmConfig,
self.heartbeat_socket: Socket = self.context.socket(zmq.constants.PULL)
self.heartbeat_socket.connect(f"{ipc_path}{IPC_HEALTH_EXT}")

# IPC path for acking tokens.
self.tokens_socket: Socket = self.context.socket(zmq.constants.PULL)
self.tokens_socket.connect(f"{ipc_path}{IPC_TOKENS_EXT}")

# IPC path for the data socket.
self.data_ipc_path = f"{ipc_path}{IPC_DATA_EXT}"

@@ -131,6 +136,16 @@ def __init__(self, ipc_path: str, engine_config: VllmConfig,
# build the Client in an executor to enable clean shutdown.
self.output_loop: Optional[asyncio.Task] = None

# Loop to handle tokens from the LLMEngine periodically.
# Started after the MQLLMEngine is ready so that we can
# build the Client in an executor to enable clean shutdown.
self.token_loop: Optional[asyncio.Task] = None

self.max_kv_cache_size: Optional[int] = None
self.current_kv_cache_size: Optional[int] = None

self.kv_cache_size_updated_event = asyncio.Event()

# Loop to check health of the LLMEngine periodically.
# Started after the MQLLMEngine is ready.
self.health_loop: Optional[asyncio.Task] = None
@@ -185,6 +200,42 @@ async def run_heartbeat_loop(self, timeout: int):
except Exception as e:
self._set_errored(e)

async def run_token_handler_loop(self):
"""As kv cache token slots become free, they are pushed to the
tokens_socket. This loop listens to the tokens_socket and
updates the kv cache size.
"""
try:
while True:
# Poll, checking for ENGINE_DEAD
if await self.tokens_socket.poll(timeout=VLLM_RPC_TIMEOUT) == 0:
logger.debug("Waiting for tokens from MQLLMEngine.")

# If errored, alert all running requests.
if self.errored:
for queue_j in tuple(self.output_queues.values()):
queue_j.put_nowait(
ENGINE_DEAD_ERROR(self._errored_with))
return

message: Frame = await self.tokens_socket.recv(copy=False)
tokens_response = pickle.loads(message.buffer)

if isinstance(tokens_response, BaseException):
raise tokens_response

# Update the kv cache size.
self.current_kv_cache_size += tokens_response.free_token_count
self.current_kv_cache_size = min(
self.current_kv_cache_size, self.max_kv_cache_size)

# Set the event to notify that the kv cache size has
# been updated.
self.kv_cache_size_updated_event.set()

except asyncio.CancelledError:
logger.debug("Shutting down MQLLMEngineClient token handler loop.")

async def run_output_handler_loop(self):
"""Get RequestOutputs from Engine and stream to Request Queues"""

@@ -288,11 +339,19 @@ async def setup(self):

self.tracing_flag = response.tracing_enabled

self.max_kv_cache_size = response.max_kv_cache_size
self.current_kv_cache_size = self.max_kv_cache_size

# Start health_loop.
if self.health_loop is None:
self.health_loop = asyncio.create_task(
self.run_heartbeat_loop(timeout=VLLM_RPC_TIMEOUT))

# Start token handler loop.
if self.token_loop is None:
self.token_loop = asyncio.create_task(
self.run_token_handler_loop())

def close(self):
"""Destroy the ZeroMQ Context."""
# Close all sockets and terminate the context.
@@ -301,6 +360,8 @@ def close(self):
# Cancel background tasks.
if self.health_loop is not None:
self.health_loop.cancel()
if self.token_loop is not None:
self.token_loop.cancel()
if self.output_loop is not None:
self.output_loop.cancel()

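A condensed sketch of the client-side bookkeeping added above, assuming only what the diff shows: free-token notifications raise current_kv_cache_size, the value is clamped to max_kv_cache_size, and an asyncio.Event wakes anything waiting for capacity. This is a simplified stand-in, not the real MQLLMEngineClient.

import asyncio
from typing import Optional


class KVCacheTracker:
    """Simplified stand-in for the client-side free-token accounting."""

    def __init__(self) -> None:
        self.max_kv_cache_size: Optional[int] = None
        self.current_kv_cache_size: Optional[int] = None
        self.kv_cache_size_updated_event = asyncio.Event()

    def on_startup(self, max_kv_cache_size: int) -> None:
        # Mirrors setup(): the RPCStartupResponse seeds both counters.
        self.max_kv_cache_size = max_kv_cache_size
        self.current_kv_cache_size = max_kv_cache_size

    def on_free_tokens(self, free_token_count: int) -> None:
        # Mirrors run_token_handler_loop(): add the freed slots, clamp
        # to the maximum, then notify any waiters.
        self.current_kv_cache_size = min(
            self.current_kv_cache_size + free_token_count,
            self.max_kv_cache_size)
        self.kv_cache_size_updated_event.set()


async def demo() -> None:
    tracker = KVCacheTracker()
    tracker.on_startup(max_kv_cache_size=65536)
    tracker.current_kv_cache_size -= 1500   # pretend a request was admitted
    tracker.on_free_tokens(1500)            # its slots come back
    await tracker.kv_cache_size_updated_event.wait()
    print(tracker.current_kv_cache_size)    # 65536 again


asyncio.run(demo())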
35 changes: 33 additions & 2 deletions vllm/engine/multiprocessing/engine.py
@@ -16,6 +16,7 @@
# yapf: disable
from vllm.engine.multiprocessing import (ENGINE_DEAD_ERROR, IPC_DATA_EXT,
IPC_HEALTH_EXT, IPC_INPUT_EXT,
IPC_TOKENS_EXT,
IPC_OUTPUT_EXT, REQUEST_OUTPUTS_T,
VLLM_RPC_SUCCESS_STR, RPCAbortRequest,
RPCAdapterLoadedResponse, RPCError,
@@ -27,7 +28,8 @@
RPCResetPrefixCacheRequest,
RPCSleepRequest, RPCStartupRequest,
RPCStartupResponse,
RPCUProfileRequest, RPCWakeUpRequest)
RPCUProfileRequest, RPCWakeUpRequest,
FreeTokensRequest)
# yapf: enable
from vllm.logger import init_logger
from vllm.outputs import RequestOutput
@@ -93,6 +95,9 @@
self.engine.process_request_outputs_callback = \
self._async_socket_engine_callback

self.engine.send_free_tokens_callback = \
self._send_free_tokens_callback

self.ctx = zmq.Context() # type: ignore[attr-defined]

# Receive input from the client.
@@ -107,6 +112,10 @@
self.heartbeat_socket = self.ctx.socket(zmq.constants.PUSH)
self.heartbeat_socket.bind(f"{ipc_path}{IPC_HEALTH_EXT}")

# Send tokens back to client.
self.tokens_socket = self.ctx.socket(zmq.constants.PUSH)
self.tokens_socket.bind(f"{ipc_path}{IPC_TOKENS_EXT}")

# IPC path for the data socket.
self.data_ipc_path = f"{ipc_path}{IPC_DATA_EXT}"

@@ -209,7 +218,9 @@
if request == RPCStartupRequest.IS_SERVER_READY:
tracing_enabled = self.engine.is_tracing_enabled()
response = RPCStartupResponse(
tracing_enabled=tracing_enabled)
tracing_enabled=tracing_enabled,
max_kv_cache_size=self.engine.get_max_kv_cache_size(),
)

except Exception as e:
response = e
@@ -411,9 +422,29 @@
def _async_socket_engine_callback(self,
request_outputs: REQUEST_OUTPUTS_T):
"""Callback used by engine to make socket handling async with GPU."""
self._send_free_tokens_for_outputs(request_outputs)
self._send_outputs(request_outputs)
self.handle_new_input()

def _send_free_tokens_for_outputs(self, request_outputs: REQUEST_OUTPUTS_T):
"""Send free tokens for outputs if available."""
free_tokens = 0

Check failure (GitHub Actions / pre-commit): mypy and Ruff report a syntax error at vllm/engine/multiprocessing/engine.py:431: unexpected indentation.
for output in request_outputs:
if not isinstance(output, RequestOutput):
continue
free_tokens += len(output.prompt_token_ids)
free_tokens += output.max_tokens

if free_tokens > 0:
self._send_free_tokens_callback(free_tokens)

def _send_free_tokens_callback(self, free_tokens: int):
"""Callback used by engine to send free tokens to the client."""
if not self.tokens_socket.closed:
self.tokens_socket.send_multipart(
(pickle.dumps(
FreeTokensRequest(free_token_count=free_tokens)),),
copy=False)

def _set_errored(self, e: BaseException):
"""Log and set errored status if this is the first issue."""
if self._errored_with is None:
@@ -441,7 +472,7 @@
return self.engine.is_sleeping()


def signal_handler(*_) -> None:

Check failure (GitHub Actions / pre-commit): Ruff reports a syntax error at vllm/engine/multiprocessing/engine.py:475: expected a statement.
raise KeyboardInterrupt("MQLLMEngine terminated")


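For reviewers unfamiliar with the socket layout, the sketch below reproduces the PUSH/PULL pattern now used for token notifications, in a single process over an inproc endpoint; the real code binds and connects an ipc path suffixed with IPC_TOKENS_EXT, and the PULL side runs inside the client's asyncio token loop.

import pickle
from dataclasses import dataclass

import zmq


@dataclass
class FreeTokensRequest:
    free_token_count: int


ctx = zmq.Context()

# Engine side: bind a PUSH socket for token notifications.
push = ctx.socket(zmq.PUSH)
push.bind("inproc://tokens")   # real code: f"{ipc_path}{IPC_TOKENS_EXT}"

# Client side: connect a PULL socket to the same endpoint.
pull = ctx.socket(zmq.PULL)
pull.connect("inproc://tokens")

# Engine pushes a pickled FreeTokensRequest, as in _send_free_tokens_callback.
push.send_multipart(
    (pickle.dumps(FreeTokensRequest(free_token_count=256)),), copy=False)

# Client receives the frame and unpickles it, as in run_token_handler_loop.
frame = pull.recv(copy=False)
print(pickle.loads(frame.buffer).free_token_count)   # 256

push.close()
pull.close()
ctx.term()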