Skip to content

Commit

Permalink
Use threading in AsyncMapper.produce()
Browse files Browse the repository at this point in the history
  • Loading branch information
rlamy authored and skshetry committed Nov 13, 2024
1 parent 3f5e9f8 commit beadae3
Show file tree
Hide file tree
Showing 2 changed files with 47 additions and 4 deletions.
19 changes: 15 additions & 4 deletions src/datachain/asyn.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,12 @@
import asyncio
from collections.abc import AsyncIterable, Awaitable, Coroutine, Iterable, Iterator
from collections.abc import (
AsyncIterable,
Awaitable,
Coroutine,
Generator,
Iterable,
Iterator,
)
from concurrent.futures import ThreadPoolExecutor
from heapq import heappop, heappush
from typing import Any, Callable, Generic, Optional, TypeVar
Expand Down Expand Up @@ -54,9 +61,13 @@ def start_task(self, coro: Coroutine) -> asyncio.Task:
task.add_done_callback(self._tasks.discard)
return task

async def produce(self) -> None:
def _produce(self) -> None:
for item in self.iterable:
await self.work_queue.put(item)
fut = asyncio.run_coroutine_threadsafe(self.work_queue.put(item), self.loop)
fut.result() # wait until the item is in the queue

async def produce(self) -> None:
await self.to_thread(self._produce)

async def worker(self) -> None:
while (item := await self.work_queue.get()) is not None:
Expand Down Expand Up @@ -132,7 +143,7 @@ async def _break_iteration(self) -> None:
self.result_queue.get_nowait()
await self.result_queue.put(None)

def iterate(self, timeout=None) -> Iterable[ResultT]:
def iterate(self, timeout=None) -> Generator[ResultT, None, None]:
init = asyncio.run_coroutine_threadsafe(self.init(), self.loop)
init.result(timeout=1)
async_run = asyncio.run_coroutine_threadsafe(self.run(), self.loop)
Expand Down
32 changes: 32 additions & 0 deletions tests/unit/test_asyn.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import functools
from collections import Counter
from contextlib import contextmanager
from queue import Queue

import pytest
from fsspec.asyn import sync
Expand Down Expand Up @@ -111,6 +112,37 @@ async def process(row):
list(mapper.iterate(timeout=4))


@pytest.mark.parametrize("create_mapper", [AsyncMapper, OrderedMapper])
def test_mapper_deadlock(create_mapper):
    """Drive a mapper from a blocking, queue-backed iterator.

    The input iterator blocks on ``Queue.get()`` between items, and a
    ``None`` sentinel marks the end of the stream. The test checks that
    every item put on the queue comes back out of ``mapper.iterate()``
    and that iteration ends once the sentinel is sent.
    """
    source = Queue()
    inputs = range(50)

    def drain(pending):
        # Yield queued items until the None sentinel shows up.
        while True:
            item = pending.get()
            if item is None:
                return
            yield item

    async def process(x):
        return x

    mapper = create_mapper(process, drain(source), workers=10, loop=get_loop())
    it = mapper.iterate(timeout=4)
    for value in inputs:
        source.put(value)

    # Check that we can get as many objects out as we put in, without deadlock
    result = [next(it) for _ in inputs]
    if not mapper.order_preserving:
        assert set(result) == set(inputs)
    else:
        assert result == list(inputs)

    # Check that iteration terminates cleanly
    source.put(None)
    leftover = list(it)
    assert leftover == []


@pytest.mark.parametrize("create_mapper", [AsyncMapper, OrderedMapper])
@settings(deadline=None)
@given(
Expand Down

0 comments on commit beadae3

Please sign in to comment.