@@ -32,7 +32,7 @@ def __init__(self, component, engine, default_sampling_params):
3232 self .kv_publisher = None
3333
3434 @abstractmethod
35- async def generate (self , request ) -> AsyncGenerator [dict , None ]:
35+ async def generate (self , request , context ) -> AsyncGenerator [dict , None ]:
3636 raise NotImplementedError
3737
3838 async def clear_kv_blocks (self , request = None ):
@@ -110,7 +110,7 @@ def cleanup(self):
110110 self ._prefill_check_task .cancel ()
111111 super ().cleanup ()
112112
113- async def generate (self , request ):
113+ async def generate (self , request , context ):
114114 request_id = str (uuid .uuid4 ().hex )
115115 logger .debug (f"New Request ID: { request_id } " )
116116
@@ -147,9 +147,20 @@ async def generate(self, request):
147147
148148 # TODO Change to prefill queue
149149 if self .prefill_worker_client is not None :
150- prefill_response = await anext (
151- await self .prefill_worker_client .round_robin (prefill_request )
152- )
150+ try :
151+ prefill_response = await anext (
152+ await self .prefill_worker_client .round_robin (
153+ prefill_request , context
154+ )
155+ )
156+ except Exception as e :
157+ # TODO: Cancellation does not propagate until the first token is received
158+ if context .is_stopped () or context .is_killed ():
159+ logger .debug (f"Aborted Remote Prefill Request ID: { request_id } " )
160+ # TODO: Raise asyncio.CancelledError into bindings
161+ return
162+ raise e
163+
153164 prefill_response = MyRequestOutput .model_validate_json (
154165 prefill_response .data ()
155166 )
@@ -162,14 +173,20 @@ async def generate(self, request):
162173 ] = prefill_response .kv_transfer_params
163174
164175 async for tok in self .generate_tokens (prompt , sampling_params , request_id ):
176+ if context .is_stopped () or context .is_killed ():
177+ await self .engine_client .abort (request_id )
178+ logger .debug (f"Aborted Request ID: { request_id } " )
179+ # TODO: Raise asyncio.CancelledError into bindings
180+ break
181+
165182 yield tok
166183
167184
168185class PrefillWorkerHandler (BaseWorkerHandler ):
169186 def __init__ (self , component , engine , default_sampling_params ):
170187 super ().__init__ (component , engine , default_sampling_params )
171188
172- async def generate (self , request ):
189+ async def generate (self , request , context ):
173190 request_id = request ["request_id" ]
174191 logger .debug (f"New Prefill Request ID: { request_id } " )
175192
@@ -181,6 +198,12 @@ async def generate(self, request):
181198 # Generate only 1 token in prefill
182199 try :
183200 async for res in gen :
201+ if context .is_stopped () or context .is_killed ():
202+ await self .engine_client .abort (request_id )
203+ logger .debug (f"Aborted Prefill Request ID: { request_id } " )
204+ # TODO: Raise asyncio.CancelledError into bindings
205+ break
206+
184207 logger .debug (f"kv transfer params: { res .kv_transfer_params } " )
185208 yield MyRequestOutput (
186209 request_id = res .request_id ,
0 commit comments