Skip to content

Commit f20a1e9

Browse files
committed
small fix
1 parent ee607bd commit f20a1e9

File tree

1 file changed

+4
-8
lines changed
  • multimodal/vl2l/src/mlperf_inference_multimodal_vl2l

1 file changed

+4
-8
lines changed

multimodal/vl2l/src/mlperf_inference_multimodal_vl2l/task.py

Lines changed: 4 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -84,7 +84,7 @@ async def _cancel_all_tasks() -> None:
8484
tasks = [
8585
task
8686
for task in asyncio.all_tasks(self.event_loop)
87-
if not task.done()
87+
if task is not asyncio.current_task() and not task.done()
8888
]
8989
for task in tasks:
9090
task.cancel()
@@ -183,9 +183,7 @@ def estimated_num_performance_samples(self) -> int:
183183
"""
184184
estimation_indices = random.sample(
185185
range(self.total_num_samples),
186-
k=min(
187-
MAX_NUM_ESTIMATION_PERFORMANCE_SAMPLES,
188-
self.total_num_samples),
186+
k=min(MAX_NUM_ESTIMATION_PERFORMANCE_SAMPLES, self.total_num_samples),
189187
)
190188
estimation_samples = [
191189
self.formulate_loaded_sample(
@@ -250,8 +248,7 @@ def _unload_samples_from_ram(query_sample_indices: list[int]) -> None:
250248
_unload_samples_from_ram,
251249
)
252250

253-
async def _query_endpoint_async_batch(
254-
self, query_sample: lg.QuerySample) -> None:
251+
async def _query_endpoint_async_batch(self, query_sample: lg.QuerySample) -> None:
255252
"""Query the endpoint through the async OpenAI API client."""
256253
try:
257254
sample = self.loaded_samples[query_sample.index]
@@ -328,8 +325,7 @@ async def _query_endpoint_async_batch(
328325
],
329326
)
330327

331-
async def _query_endpoint_async_stream(
332-
self, query_sample: lg.QuerySample) -> None:
328+
async def _query_endpoint_async_stream(self, query_sample: lg.QuerySample) -> None:
333329
"""Query the endpoint through the async OpenAI API client."""
334330
ttft_set = False
335331
try:

0 commit comments

Comments
 (0)