Skip to content

Commit

Permalink
asyncmapper: shutdown producer on generator close (#597)
Browse files Browse the repository at this point in the history
  • Loading branch information
skshetry authored Nov 20, 2024
1 parent 9fd3155 commit 56cc2ad
Showing 1 changed file with 21 additions and 0 deletions.
21 changes: 21 additions & 0 deletions src/datachain/asyn.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import asyncio
import threading
from collections.abc import (
AsyncIterable,
Awaitable,
Expand Down Expand Up @@ -54,6 +55,7 @@ def __init__(
self.loop = get_loop() if loop is None else loop
self.pool = ThreadPoolExecutor(workers)
self._tasks: set[asyncio.Task] = set()
self._shutdown_producer = threading.Event()

def start_task(self, coro: Coroutine) -> asyncio.Task:
task = self.loop.create_task(coro)
Expand All @@ -63,12 +65,30 @@ def start_task(self, coro: Coroutine) -> asyncio.Task:

def _produce(self) -> None:
for item in self.iterable:
if self._shutdown_producer.is_set():
return
fut = asyncio.run_coroutine_threadsafe(self.work_queue.put(item), self.loop)
fut.result() # wait until the item is in the queue

async def produce(self) -> None:
await self.to_thread(self._produce)

def shutdown_producer(self) -> None:
"""
Signal the producer to stop and drain any remaining items from the work_queue.
This method sets an internal event, `_shutdown_producer`, which tells the
producer that it should stop adding items to the queue. To ensure that the
producer notices this signal promptly, we also attempt to drain any items
currently in the queue, clearing it so that the event can be checked without
delay.
"""
self._shutdown_producer.set()
q = self.work_queue
while not q.empty():
q.get_nowait()
q.task_done()

async def worker(self) -> None:
while (item := await self.work_queue.get()) is not None:
try:
Expand Down Expand Up @@ -156,6 +176,7 @@ def iterate(self, timeout=None) -> Generator[ResultT, None, None]:
if exc := async_run.exception():
raise exc
finally:
self.shutdown_producer()
if not async_run.done():
async_run.cancel()

Expand Down

0 comments on commit 56cc2ad

Please sign in to comment.