Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 29 additions & 6 deletions components/backends/vllm/src/dynamo/vllm/handlers.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ def __init__(self, component, engine, default_sampling_params):
self.kv_publisher = None

@abstractmethod
async def generate(self, request) -> AsyncGenerator[dict, None]:
async def generate(self, request, context) -> AsyncGenerator[dict, None]:
raise NotImplementedError

async def clear_kv_blocks(self, request=None):
Expand Down Expand Up @@ -110,7 +110,7 @@ def cleanup(self):
self._prefill_check_task.cancel()
super().cleanup()

async def generate(self, request):
async def generate(self, request, context):
request_id = str(uuid.uuid4().hex)
logger.debug(f"New Request ID: {request_id}")

Expand Down Expand Up @@ -147,9 +147,20 @@ async def generate(self, request):

# TODO Change to prefill queue
if self.prefill_worker_client is not None:
prefill_response = await anext(
await self.prefill_worker_client.round_robin(prefill_request)
)
try:
prefill_response = await anext(
await self.prefill_worker_client.round_robin(
prefill_request, context
)
)
except Exception as e:
# TODO: Cancellation does not propagate until the first token is received
if context.is_stopped() or context.is_killed():
logger.debug(f"Aborted Remote Prefill Request ID: {request_id}")
# TODO: Raise asyncio.CancelledError into bindings
return
raise e

prefill_response = MyRequestOutput.model_validate_json(
prefill_response.data()
)
Expand All @@ -162,14 +173,20 @@ async def generate(self, request):
] = prefill_response.kv_transfer_params

async for tok in self.generate_tokens(prompt, sampling_params, request_id):
if context.is_stopped() or context.is_killed():
await self.engine_client.abort(request_id)
logger.debug(f"Aborted Request ID: {request_id}")
# TODO: Raise asyncio.CancelledError into bindings
break

yield tok


class PrefillWorkerHandler(BaseWorkerHandler):
def __init__(self, component, engine, default_sampling_params):
super().__init__(component, engine, default_sampling_params)

async def generate(self, request):
async def generate(self, request, context):
request_id = request["request_id"]
logger.debug(f"New Prefill Request ID: {request_id}")

Expand All @@ -181,6 +198,12 @@ async def generate(self, request):
# Generate only 1 token in prefill
try:
async for res in gen:
if context.is_stopped() or context.is_killed():
await self.engine_client.abort(request_id)
logger.debug(f"Aborted Prefill Request ID: {request_id}")
# TODO: Raise asyncio.CancelledError into bindings
break

logger.debug(f"kv transfer params: {res.kv_transfer_params}")
yield MyRequestOutput(
request_id=res.request_id,
Expand Down
Loading