@@ -84,7 +84,7 @@ async def _cancel_all_tasks() -> None:
             tasks = [
                 task
                 for task in asyncio.all_tasks(self.event_loop)
-                if not task.done()
+                if task is not asyncio.current_task() and not task.done()
             ]
             for task in tasks:
                 task.cancel()
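For context on this hunk: without the `asyncio.current_task()` guard, the cancelling coroutine would appear in `asyncio.all_tasks()` and cancel itself mid-loop, raising `CancelledError` before the remaining tasks are handled. A minimal, self-contained sketch of the pattern (the worker and function names here are illustrative, not from the patch):

import asyncio

async def _worker(seconds: int) -> None:
    # Placeholder workload that stays pending long enough to be cancelled.
    await asyncio.sleep(seconds)

async def cancel_all_but_self() -> None:
    # Skip the task running this coroutine; cancelling it here would
    # abort this very loop.
    tasks = [
        task
        for task in asyncio.all_tasks()
        if task is not asyncio.current_task() and not task.done()
    ]
    for task in tasks:
        task.cancel()
    # Wait for the cancellations to propagate.
    await asyncio.gather(*tasks, return_exceptions=True)

async def main() -> None:
    for seconds in (10, 20):
        asyncio.ensure_future(_worker(seconds))
    await asyncio.sleep(0)  # let the workers start
    await cancel_all_but_self()

asyncio.run(main())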
@@ -183,9 +183,7 @@ def estimated_num_performance_samples(self) -> int:
         """
         estimation_indices = random.sample(
             range(self.total_num_samples),
-            k=min(
-                MAX_NUM_ESTIMATION_PERFORMANCE_SAMPLES,
-                self.total_num_samples),
+            k=min(MAX_NUM_ESTIMATION_PERFORMANCE_SAMPLES, self.total_num_samples),
         )
         estimation_samples = [
             self.formulate_loaded_sample(
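This hunk only collapses the `min(...)` call onto one line; the underlying pattern is capped sampling without replacement. The `min()` guard matters because `random.sample` raises `ValueError` when `k` exceeds the population size. A standalone sketch, with an illustrative cap value that is not taken from the patch:

import random

MAX_NUM_ESTIMATION_PERFORMANCE_SAMPLES = 1024  # illustrative cap, not from the patch

def pick_estimation_indices(total_num_samples: int) -> list[int]:
    # Sample without replacement; k can never exceed the population size,
    # which is exactly what the min() guard ensures for small datasets.
    return random.sample(
        range(total_num_samples),
        k=min(MAX_NUM_ESTIMATION_PERFORMANCE_SAMPLES, total_num_samples),
    )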
@@ -250,8 +248,7 @@ def _unload_samples_from_ram(query_sample_indices: list[int]) -> None:
             _unload_samples_from_ram,
         )

-    async def _query_endpoint_async_batch(
-            self, query_sample: lg.QuerySample) -> None:
+    async def _query_endpoint_async_batch(self, query_sample: lg.QuerySample) -> None:
         """Query the endpoint through the async OpenAI API client."""
         try:
             sample = self.loaded_samples[query_sample.index]
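The method above queries the endpoint through the async OpenAI client and reports the result back to LoadGen. A minimal sketch of that non-streaming path, assuming an openai.AsyncOpenAI client; `client`, `model_name`, and `loaded_samples` are illustrative stand-ins for the class attributes used in the patch, not the repository's actual names:

import array

import mlperf_loadgen as lg
from openai import AsyncOpenAI

client = AsyncOpenAI()  # assumes OPENAI_API_KEY / base_url are configured in the environment

async def query_endpoint_async_batch(query_sample: lg.QuerySample,
                                     loaded_samples: list[str],
                                     model_name: str) -> None:
    prompt = loaded_samples[query_sample.index]
    response = await client.chat.completions.create(
        model=model_name,
        messages=[{"role": "user", "content": prompt}],
    )
    text = response.choices[0].message.content or ""
    data = array.array("B", text.encode("utf-8"))
    ptr, _ = data.buffer_info()
    # Report the completed sample back to LoadGen.
    lg.QuerySamplesComplete(
        [lg.QuerySampleResponse(query_sample.id, ptr, len(data))]
    )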
@@ -328,8 +325,7 @@ async def _query_endpoint_async_batch(
             ],
         )

-    async def _query_endpoint_async_stream(
-            self, query_sample: lg.QuerySample) -> None:
+    async def _query_endpoint_async_stream(self, query_sample: lg.QuerySample) -> None:
         """Query the endpoint through the async OpenAI API client."""
         ttft_set = False
         try:
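The streaming variant tracks time to first token (TTFT), which is what the `ttft_set` flag above is for. A minimal sketch of that pattern under the same assumptions as before (client and model names are illustrative, not the repository's attributes):

import time

from openai import AsyncOpenAI

client = AsyncOpenAI()  # assumes OPENAI_API_KEY / base_url are configured in the environment

async def query_endpoint_async_stream(prompt: str, model_name: str) -> float:
    ttft_set = False
    ttft = 0.0
    start = time.perf_counter()
    stream = await client.chat.completions.create(
        model=model_name,
        messages=[{"role": "user", "content": prompt}],
        stream=True,
    )
    async for chunk in stream:
        delta = chunk.choices[0].delta.content if chunk.choices else None
        if delta and not ttft_set:
            # Record the latency of the first generated token exactly once.
            ttft = time.perf_counter() - start
            ttft_set = True
    return ttft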