diff --git a/src/datachain/asyn.py b/src/datachain/asyn.py
index 7c94190fa..1b87afc41 100644
--- a/src/datachain/asyn.py
+++ b/src/datachain/asyn.py
@@ -1,4 +1,5 @@
 import asyncio
+import threading
 from collections.abc import (
     AsyncIterable,
     Awaitable,
@@ -54,6 +55,7 @@ def __init__(
         self.loop = get_loop() if loop is None else loop
         self.pool = ThreadPoolExecutor(workers)
         self._tasks: set[asyncio.Task] = set()
+        self._shutdown_producer = threading.Event()
 
     def start_task(self, coro: Coroutine) -> asyncio.Task:
         task = self.loop.create_task(coro)
@@ -63,12 +65,35 @@ def start_task(self, coro: Coroutine) -> asyncio.Task:
 
     def _produce(self) -> None:
         for item in self.iterable:
+            if self._shutdown_producer.is_set():
+                return
             fut = asyncio.run_coroutine_threadsafe(self.work_queue.put(item), self.loop)
             fut.result()  # wait until the item is in the queue
 
     async def produce(self) -> None:
         await self.to_thread(self._produce)
 
+    def shutdown_producer(self) -> None:
+        """
+        Signal the producer to stop and drain any remaining items from the work_queue.
+
+        Setting `_shutdown_producer` makes `_produce` return before queueing its
+        next item. `_produce` waits in `fut.result()` until its pending `put` has
+        completed, so the queue is drained as well: if that `put` is waiting for
+        room in the queue, the freed slot lets it finish and the producer reach
+        the event check. Each discarded item is marked done with `task_done()`.
+        """
+        self._shutdown_producer.set()
+        q = self.work_queue
+        while True:
+            try:
+                q.get_nowait()
+            except asyncio.QueueEmpty:
+                # Workers consume from the same queue, so it can become empty
+                # between an emptiness check and the get; exit quietly instead.
+                break
+            q.task_done()
+
     async def worker(self) -> None:
         while (item := await self.work_queue.get()) is not None:
             try:
@@ -156,6 +181,7 @@ def iterate(self, timeout=None) -> Generator[ResultT, None, None]:
             if exc := async_run.exception():
                 raise exc
         finally:
+            self.shutdown_producer()
             if not async_run.done():
                 async_run.cancel()
 