Skip to content

Commit b4603fa

Browse files
authored
feat: vLLM abort on stream stop (#2717)
1 parent e75ca6d commit b4603fa

File tree

1 file changed

+29
-6
lines changed

1 file changed

+29
-6
lines changed

components/backends/vllm/src/dynamo/vllm/handlers.py

Lines changed: 29 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ def __init__(self, component, engine, default_sampling_params):
3232
self.kv_publisher = None
3333

3434
@abstractmethod
35-
async def generate(self, request) -> AsyncGenerator[dict, None]:
35+
async def generate(self, request, context) -> AsyncGenerator[dict, None]:
3636
raise NotImplementedError
3737

3838
async def clear_kv_blocks(self, request=None):
@@ -110,7 +110,7 @@ def cleanup(self):
110110
self._prefill_check_task.cancel()
111111
super().cleanup()
112112

113-
async def generate(self, request):
113+
async def generate(self, request, context):
114114
request_id = str(uuid.uuid4().hex)
115115
logger.debug(f"New Request ID: {request_id}")
116116

@@ -147,9 +147,20 @@ async def generate(self, request):
147147

148148
# TODO Change to prefill queue
149149
if self.prefill_worker_client is not None:
150-
prefill_response = await anext(
151-
await self.prefill_worker_client.round_robin(prefill_request)
152-
)
150+
try:
151+
prefill_response = await anext(
152+
await self.prefill_worker_client.round_robin(
153+
prefill_request, context
154+
)
155+
)
156+
except Exception as e:
157+
# TODO: Cancellation does not propagate until the first token is received
158+
if context.is_stopped() or context.is_killed():
159+
logger.debug(f"Aborted Remote Prefill Request ID: {request_id}")
160+
# TODO: Raise asyncio.CancelledError into bindings
161+
return
162+
raise e
163+
153164
prefill_response = MyRequestOutput.model_validate_json(
154165
prefill_response.data()
155166
)
@@ -162,14 +173,20 @@ async def generate(self, request):
162173
] = prefill_response.kv_transfer_params
163174

164175
async for tok in self.generate_tokens(prompt, sampling_params, request_id):
176+
if context.is_stopped() or context.is_killed():
177+
await self.engine_client.abort(request_id)
178+
logger.debug(f"Aborted Request ID: {request_id}")
179+
# TODO: Raise asyncio.CancelledError into bindings
180+
break
181+
165182
yield tok
166183

167184

168185
class PrefillWorkerHandler(BaseWorkerHandler):
169186
def __init__(self, component, engine, default_sampling_params):
170187
super().__init__(component, engine, default_sampling_params)
171188

172-
async def generate(self, request):
189+
async def generate(self, request, context):
173190
request_id = request["request_id"]
174191
logger.debug(f"New Prefill Request ID: {request_id}")
175192

@@ -181,6 +198,12 @@ async def generate(self, request):
181198
# Generate only 1 token in prefill
182199
try:
183200
async for res in gen:
201+
if context.is_stopped() or context.is_killed():
202+
await self.engine_client.abort(request_id)
203+
logger.debug(f"Aborted Prefill Request ID: {request_id}")
204+
# TODO: Raise asyncio.CancelledError into bindings
205+
break
206+
184207
logger.debug(f"kv transfer params: {res.kv_transfer_params}")
185208
yield MyRequestOutput(
186209
request_id=res.request_id,

0 commit comments

Comments
 (0)