From 80e7169d649147769c5978ae0f0c8f2ba546a073 Mon Sep 17 00:00:00 2001 From: Jacky <18255193+kthui@users.noreply.github.com> Date: Mon, 25 Aug 2025 13:13:53 -0700 Subject: [PATCH] feat: vLLM explicit cancellation --- .../backends/vllm/src/dynamo/vllm/handlers.py | 35 +++++++++++++++---- 1 file changed, 29 insertions(+), 6 deletions(-) diff --git a/components/backends/vllm/src/dynamo/vllm/handlers.py b/components/backends/vllm/src/dynamo/vllm/handlers.py index 2c4590a898..d374bd24bb 100644 --- a/components/backends/vllm/src/dynamo/vllm/handlers.py +++ b/components/backends/vllm/src/dynamo/vllm/handlers.py @@ -32,7 +32,7 @@ def __init__(self, component, engine, default_sampling_params): self.kv_publisher = None @abstractmethod - async def generate(self, request) -> AsyncGenerator[dict, None]: + async def generate(self, request, context) -> AsyncGenerator[dict, None]: raise NotImplementedError async def clear_kv_blocks(self, request=None): @@ -110,7 +110,7 @@ def cleanup(self): self._prefill_check_task.cancel() super().cleanup() - async def generate(self, request): + async def generate(self, request, context): request_id = str(uuid.uuid4().hex) logger.debug(f"New Request ID: {request_id}") @@ -147,9 +147,20 @@ async def generate(self, request): # TODO Change to prefill queue if self.prefill_worker_client is not None: - prefill_response = await anext( - await self.prefill_worker_client.round_robin(prefill_request) - ) + try: + prefill_response = await anext( + await self.prefill_worker_client.round_robin( + prefill_request, context + ) + ) + except Exception as e: + # TODO: Cancellation does not propagate until the first token is received + if context.is_stopped() or context.is_killed(): + logger.debug(f"Aborted Remote Prefill Request ID: {request_id}") + # TODO: Raise asyncio.CancelledError into bindings + return + raise e + prefill_response = MyRequestOutput.model_validate_json( prefill_response.data() ) @@ -162,6 +173,12 @@ async def generate(self, request): ] = prefill_response.kv_transfer_params async for tok in self.generate_tokens(prompt, sampling_params, request_id): + if context.is_stopped() or context.is_killed(): + await self.engine_client.abort(request_id) + logger.debug(f"Aborted Request ID: {request_id}") + # TODO: Raise asyncio.CancelledError into bindings + break + yield tok @@ -169,7 +186,7 @@ class PrefillWorkerHandler(BaseWorkerHandler): def __init__(self, component, engine, default_sampling_params): super().__init__(component, engine, default_sampling_params) - async def generate(self, request): + async def generate(self, request, context): request_id = request["request_id"] logger.debug(f"New Prefill Request ID: {request_id}") @@ -181,6 +198,12 @@ async def generate(self, request): # Generate only 1 token in prefill try: async for res in gen: + if context.is_stopped() or context.is_killed(): + await self.engine_client.abort(request_id) + logger.debug(f"Aborted Prefill Request ID: {request_id}") + # TODO: Raise asyncio.CancelledError into bindings + break + logger.debug(f"kv transfer params: {res.kv_transfer_params}") yield MyRequestOutput( request_id=res.request_id,