From d77cd443a330eb9ccb48e24af38bf77a1bd9fd46 Mon Sep 17 00:00:00 2001
From: Ronan Lamy
Date: Fri, 18 Oct 2024 20:05:20 +0100
Subject: [PATCH 1/4] Use threading in AsyncMapper.produce()

---
 src/datachain/asyn.py   | 19 +++++++++++++++----
 tests/unit/test_asyn.py | 32 ++++++++++++++++++++++++++++++++
 2 files changed, 47 insertions(+), 4 deletions(-)

diff --git a/src/datachain/asyn.py b/src/datachain/asyn.py
index 4e42f4a8c..7c94190fa 100644
--- a/src/datachain/asyn.py
+++ b/src/datachain/asyn.py
@@ -1,5 +1,12 @@
 import asyncio
-from collections.abc import AsyncIterable, Awaitable, Coroutine, Iterable, Iterator
+from collections.abc import (
+    AsyncIterable,
+    Awaitable,
+    Coroutine,
+    Generator,
+    Iterable,
+    Iterator,
+)
 from concurrent.futures import ThreadPoolExecutor
 from heapq import heappop, heappush
 from typing import Any, Callable, Generic, Optional, TypeVar
@@ -54,9 +61,13 @@ def start_task(self, coro: Coroutine) -> asyncio.Task:
         task.add_done_callback(self._tasks.discard)
         return task
 
-    async def produce(self) -> None:
+    def _produce(self) -> None:
         for item in self.iterable:
-            await self.work_queue.put(item)
+            fut = asyncio.run_coroutine_threadsafe(self.work_queue.put(item), self.loop)
+            fut.result()  # wait until the item is in the queue
+
+    async def produce(self) -> None:
+        await self.to_thread(self._produce)
 
     async def worker(self) -> None:
         while (item := await self.work_queue.get()) is not None:
@@ -132,7 +143,7 @@ async def _break_iteration(self) -> None:
             self.result_queue.get_nowait()
         await self.result_queue.put(None)
 
-    def iterate(self, timeout=None) -> Iterable[ResultT]:
+    def iterate(self, timeout=None) -> Generator[ResultT, None, None]:
         init = asyncio.run_coroutine_threadsafe(self.init(), self.loop)
         init.result(timeout=1)
         async_run = asyncio.run_coroutine_threadsafe(self.run(), self.loop)
diff --git a/tests/unit/test_asyn.py b/tests/unit/test_asyn.py
index a97bb732a..e102dd3a2 100644
--- a/tests/unit/test_asyn.py
+++ b/tests/unit/test_asyn.py
@@ -2,6 +2,7 @@
 import functools
 from collections import Counter
 from contextlib import contextmanager
+from queue import Queue
 
 import pytest
 from fsspec.asyn import sync
@@ -111,6 +112,37 @@ async def process(row):
         list(mapper.iterate(timeout=4))
 
 
+@pytest.mark.parametrize("create_mapper", [AsyncMapper, OrderedMapper])
+def test_mapper_deadlock(create_mapper):
+    queue = Queue()
+    inputs = range(50)
+
+    def as_iter(queue):
+        while (item := queue.get()) is not None:
+            yield item
+
+    async def process(x):
+        return x
+
+    mapper = create_mapper(process, as_iter(queue), workers=10, loop=get_loop())
+    it = mapper.iterate(timeout=4)
+    for i in inputs:
+        queue.put(i)
+
+    # Check that we can get as many objects out as we put in, without deadlock
+    result = []
+    for _ in range(len(inputs)):
+        result.append(next(it))
+    if mapper.order_preserving:
+        assert result == list(inputs)
+    else:
+        assert set(result) == set(inputs)
+
+    # Check that iteration terminates cleanly
+    queue.put(None)
+    assert list(it) == []
+
+
 @pytest.mark.parametrize("create_mapper", [AsyncMapper, OrderedMapper])
 @settings(deadline=None)
 @given(
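
The change above moves the iteration of `self.iterable` into a worker thread and feeds items back to the event loop with `asyncio.run_coroutine_threadsafe`, so a blocking input iterator can no longer stall the loop (the new `test_mapper_deadlock` exercises exactly that). A minimal, self-contained sketch of the same producer-thread pattern, using only the standard library (all names and values here are illustrative, not DataChain API):

    import asyncio
    import threading

    def threaded_producer(loop, queue, items):
        # Runs in a plain thread: each put() is scheduled on the loop, and the
        # thread (not the loop) blocks on the returned Future if the queue is full.
        for item in items:
            asyncio.run_coroutine_threadsafe(queue.put(item), loop).result()
        asyncio.run_coroutine_threadsafe(queue.put(None), loop).result()  # sentinel

    async def consume(queue):
        results = []
        while (item := await queue.get()) is not None:
            results.append(item)
        return results

    async def main():
        queue: asyncio.Queue = asyncio.Queue(maxsize=2)
        producer = threading.Thread(
            target=threaded_producer,
            args=(asyncio.get_running_loop(), queue, range(10)),
        )
        producer.start()
        print(await consume(queue))  # [0, 1, ..., 9]
        producer.join()

    asyncio.run(main())
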
From a0579ae0ad4dcb2ece0b04818b82d31254e1ec56 Mon Sep 17 00:00:00 2001
From: Ronan Lamy
Date: Wed, 9 Oct 2024 15:39:21 +0100
Subject: [PATCH 2/4] Implement prefetching in .gen() and .map()

---
 src/datachain/data_storage/warehouse.py |  5 +-
 src/datachain/lib/dc.py                 |  7 ++-
 src/datachain/lib/file.py               |  5 ++
 src/datachain/lib/settings.py           | 12 ++++-
 src/datachain/lib/udf.py                | 63 ++++++++++++++++++-------
 src/datachain/query/dataset.py          | 52 ++++++++++----------
 tests/func/test_datachain.py            |  7 ++-
 7 files changed, 101 insertions(+), 50 deletions(-)

diff --git a/src/datachain/data_storage/warehouse.py b/src/datachain/data_storage/warehouse.py
index 8e4bdd51a..f3e87cd72 100644
--- a/src/datachain/data_storage/warehouse.py
+++ b/src/datachain/data_storage/warehouse.py
@@ -232,7 +232,10 @@ def dataset_select_paginated(
         if limit < page_size:
             paginated_query = paginated_query.limit(None).limit(limit)
 
-        results = self.dataset_rows_select(paginated_query.offset(offset))
+        # Ensure we're using a thread-local connection
+        with self.clone() as wh:
+            # Cursor results are not thread-safe, so we convert them to a list
+            results = list(wh.dataset_rows_select(paginated_query.offset(offset)))
 
         processed = False
         for row in results:
diff --git a/src/datachain/lib/dc.py b/src/datachain/lib/dc.py
index 5e04c349d..3d67688af 100644
--- a/src/datachain/lib/dc.py
+++ b/src/datachain/lib/dc.py
@@ -334,6 +334,7 @@ def settings(
         parallel=None,
         workers=None,
         min_task_size=None,
+        prefetch: Optional[int] = None,
         sys: Optional[bool] = None,
     ) -> "Self":
         """Change settings for chain.
@@ -360,7 +361,7 @@ def settings(
         if sys is None:
             sys = self._sys
         settings = copy.copy(self._settings)
-        settings.add(Settings(cache, parallel, workers, min_task_size))
+        settings.add(Settings(cache, parallel, workers, min_task_size, prefetch))
         return self._evolve(settings=settings, _sys=sys)
 
     def reset_settings(self, settings: Optional[Settings] = None) -> "Self":
@@ -882,6 +883,8 @@ def map(
         ```
         """
         udf_obj = self._udf_to_obj(Mapper, func, params, output, signal_map)
+        if (prefetch := self._settings.prefetch) is not None:
+            udf_obj.prefetch = prefetch
 
         return self._evolve(
             query=self._query.add_signals(
@@ -919,6 +922,8 @@ def gen(
         ```
         """
         udf_obj = self._udf_to_obj(Generator, func, params, output, signal_map)
+        if (prefetch := self._settings.prefetch) is not None:
+            udf_obj.prefetch = prefetch
         return self._evolve(
             query=self._query.generate(
                 udf_obj.to_udf_wrapper(),
diff --git a/src/datachain/lib/file.py b/src/datachain/lib/file.py
index caf19015a..10536eb04 100644
--- a/src/datachain/lib/file.py
+++ b/src/datachain/lib/file.py
@@ -268,6 +268,11 @@ def ensure_cached(self) -> None:
         client = self._catalog.get_client(self.source)
         client.download(self, callback=self._download_cb)
 
+    async def _prefetch(self) -> None:
+        if self._caching_enabled:
+            client = self._catalog.get_client(self.source)
+            await client._download(self, callback=self._download_cb)
+
     def get_local_path(self) -> Optional[str]:
         """Return path to a file in a local cache.
 
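
With `prefetch` accepted by `DataChain.settings()` above, and stored on `Settings` as shown just below, prefetching can be switched on per chain. A usage sketch mirroring the test added at the end of this patch (the bucket URI and dataset name are illustrative):

    from datachain.lib.dc import DataChain
    from datachain.lib.file import File

    def new_signal(file: File) -> str:
        # With cache and prefetch enabled, upcoming File objects are already
        # being downloaded into the local cache while this row is processed.
        with file.open() as f:
            return file.name + " -> " + f.read().decode("utf-8")

    chain = (
        DataChain.from_storage("s3://example-bucket/", anon=True)  # illustrative URI
        .settings(cache=True, prefetch=2)  # prefetch=0 turns prefetching off
        .map(signal=new_signal)
        .save("with-signal")
    )
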
diff --git a/src/datachain/lib/settings.py b/src/datachain/lib/settings.py
index 1f3722a44..fe294d950 100644
--- a/src/datachain/lib/settings.py
+++ b/src/datachain/lib/settings.py
@@ -7,11 +7,19 @@ def __init__(self, msg):
 
 
 class Settings:
-    def __init__(self, cache=None, parallel=None, workers=None, min_task_size=None):
+    def __init__(
+        self,
+        cache=None,
+        parallel=None,
+        workers=None,
+        min_task_size=None,
+        prefetch=None,
+    ):
         self._cache = cache
         self.parallel = parallel
         self._workers = workers
         self.min_task_size = min_task_size
+        self.prefetch = prefetch
 
         if not isinstance(cache, bool) and cache is not None:
             raise SettingsError(
@@ -66,3 +74,5 @@ def add(self, settings: "Settings"):
         self.parallel = settings.parallel or self.parallel
         self._workers = settings._workers or self._workers
         self.min_task_size = settings.min_task_size or self.min_task_size
+        if settings.prefetch is not None:
+            self.prefetch = settings.prefetch
diff --git a/src/datachain/lib/udf.py b/src/datachain/lib/udf.py
index 8faa13a29..d708c0330 100644
--- a/src/datachain/lib/udf.py
+++ b/src/datachain/lib/udf.py
@@ -1,3 +1,4 @@
+import contextlib
 import sys
 import traceback
 from collections.abc import Iterable, Iterator, Mapping, Sequence
@@ -7,6 +8,7 @@
 from fsspec.callbacks import DEFAULT_CALLBACK, Callback
 from pydantic import BaseModel
 
+from datachain.asyn import AsyncMapper
 from datachain.dataset import RowDict
 from datachain.lib.convert.flatten import flatten
 from datachain.lib.data_model import DataValue
@@ -21,6 +23,8 @@
 )
 
 if TYPE_CHECKING:
+    from collections import abc
+
     from typing_extensions import Self
 
     from datachain.catalog import Catalog
@@ -276,9 +280,18 @@ def process_safe(self, obj_rows):
         return result_objs
 
 
+async def _prefetch_input(row):
+    for obj in row:
+        if isinstance(obj, File):
+            await obj._prefetch()
+    return row
+
+
 class Mapper(UDFBase):
     """Inherit from this class to pass to `DataChain.map()`."""
 
+    prefetch: int = 2
+
     def run(
         self,
         udf_fields: "Sequence[str]",
@@ -290,16 +303,22 @@ def run(
     ) -> Iterator[Iterable[UDFResult]]:
         self.catalog = catalog
         self.setup()
-
-        for row in udf_inputs:
-            id_, *udf_args = self._prepare_row_and_id(
-                row, udf_fields, cache, download_cb
-            )
-            result_objs = self.process_safe(udf_args)
-            udf_output = self._flatten_row(result_objs)
-            output = [{"sys__id": id_} | dict(zip(self.signal_names, udf_output))]
-            processed_cb.relative_update(1)
-            yield output
+        prepared_inputs: abc.Generator[Sequence[Any], None, None] = (
+            self._prepare_row_and_id(row, udf_fields, cache, download_cb)
+            for row in udf_inputs
+        )
+        if self.prefetch > 0:
+            prepared_inputs = AsyncMapper(
+                _prefetch_input, prepared_inputs, workers=self.prefetch
+            ).iterate()
+
+        with contextlib.closing(prepared_inputs):
+            for id_, *udf_args in prepared_inputs:
+                result_objs = self.process_safe(udf_args)
+                udf_output = self._flatten_row(result_objs)
+                output = [{"sys__id": id_} | dict(zip(self.signal_names, udf_output))]
+                processed_cb.relative_update(1)
+                yield output
 
         self.teardown()
 
@@ -349,6 +368,7 @@ class Generator(UDFBase):
     """Inherit from this class to pass to `DataChain.gen()`."""
 
     is_output_batched = True
+    prefetch: int = 2
 
     def run(
         self,
@@ -361,14 +381,21 @@ def run(
     ) -> Iterator[Iterable[UDFResult]]:
         self.catalog = catalog
         self.setup()
-
-        for row in udf_inputs:
-            udf_args = self._prepare_row(row, udf_fields, cache, download_cb)
-            result_objs = self.process_safe(udf_args)
-            udf_outputs = (self._flatten_row(row) for row in result_objs)
-            output = (dict(zip(self.signal_names, row)) for row in udf_outputs)
-            processed_cb.relative_update(1)
-            yield output
+        prepared_inputs: abc.Generator[Sequence[Any], None, None] = (
+            self._prepare_row(row, udf_fields, cache, download_cb) for row in udf_inputs
+        )
+        if self.prefetch > 0:
+            prepared_inputs = AsyncMapper(
+                _prefetch_input, prepared_inputs, workers=self.prefetch
+            ).iterate()
+
+        with contextlib.closing(prepared_inputs):
+            for row in prepared_inputs:
+                result_objs = self.process_safe(row)
+                udf_outputs = (self._flatten_row(row) for row in result_objs)
+                output = (dict(zip(self.signal_names, row)) for row in udf_outputs)
+                processed_cb.relative_update(1)
+                yield output
 
         self.teardown()
 
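
The `Mapper.run()` and `Generator.run()` changes above overlap object download with row processing: prepared rows flow through an `AsyncMapper` whose per-item coroutine is `_prefetch_input`, so up to `prefetch` files are fetched ahead of the row currently being processed. A rough standalone sketch of that pattern, assuming `AsyncMapper` and `get_loop` behave as introduced in patch 1 (the fetch coroutine and row values are illustrative):

    import asyncio
    from datachain.asyn import AsyncMapper, get_loop

    async def prefetch(row):
        # Stand-in for _prefetch_input()/File._prefetch(): start the download
        # early, then pass the row through unchanged for the consumer.
        await asyncio.sleep(0.1)  # pretend to download the row's file
        return row

    rows = (f"row-{i}" for i in range(10))  # lazily prepared rows
    mapper = AsyncMapper(prefetch, rows, workers=2, loop=get_loop())
    for row in mapper.iterate(timeout=10):
        # Note: AsyncMapper does not preserve input order; OrderedMapper does.
        ...  # process the row while later rows are still being fetched
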
diff --git a/src/datachain/query/dataset.py b/src/datachain/query/dataset.py
index c455c01e0..86f0eef79 100644
--- a/src/datachain/query/dataset.py
+++ b/src/datachain/query/dataset.py
@@ -473,33 +473,31 @@ def populate_udf_table(self, udf_table: "Table", query: Select) -> None:
             # Otherwise process single-threaded (faster for smaller UDFs)
             warehouse = self.catalog.warehouse
 
-            with contextlib.closing(
-                batching(warehouse.dataset_select_paginated, query)
-            ) as udf_inputs:
-                download_cb = get_download_callback()
-                processed_cb = get_processed_callback()
-                generated_cb = get_generated_callback(self.is_generator)
-                try:
-                    udf_results = self.udf.run(
-                        udf_fields,
-                        udf_inputs,
-                        self.catalog,
-                        self.is_generator,
-                        self.cache,
-                        download_cb,
-                        processed_cb,
-                    )
-                    process_udf_outputs(
-                        warehouse,
-                        udf_table,
-                        udf_results,
-                        self.udf,
-                        cb=generated_cb,
-                    )
-                finally:
-                    download_cb.close()
-                    processed_cb.close()
-                    generated_cb.close()
+            udf_inputs = batching(warehouse.dataset_select_paginated, query)
+            download_cb = get_download_callback()
+            processed_cb = get_processed_callback()
+            generated_cb = get_generated_callback(self.is_generator)
+            try:
+                udf_results = self.udf.run(
+                    udf_fields,
+                    udf_inputs,
+                    self.catalog,
+                    self.is_generator,
+                    self.cache,
+                    download_cb,
+                    processed_cb,
+                )
+                process_udf_outputs(
+                    warehouse,
+                    udf_table,
+                    udf_results,
+                    self.udf,
+                    cb=generated_cb,
+                )
+            finally:
+                download_cb.close()
+                processed_cb.close()
+                generated_cb.close()
 
         warehouse.insert_rows_done(udf_table)
diff --git a/tests/func/test_datachain.py b/tests/func/test_datachain.py
index e121ef03b..fb11cad6d 100644
--- a/tests/func/test_datachain.py
+++ b/tests/func/test_datachain.py
@@ -213,17 +213,20 @@ def test_from_storage_dependencies(cloud_test_catalog, cloud_type):
 
 
 @pytest.mark.parametrize("use_cache", [True, False])
-def test_map_file(cloud_test_catalog, use_cache):
+@pytest.mark.parametrize("prefetch", [0, 2])
+def test_map_file(cloud_test_catalog, use_cache, prefetch):
     ctc = cloud_test_catalog
 
     def new_signal(file: File) -> str:
+        assert bool(file.get_local_path()) is (use_cache and prefetch > 0)
        with file.open() as f:
             return file.name + " -> " + f.read().decode("utf-8")
 
     dc = (
         DataChain.from_storage(ctc.src_uri, session=ctc.session)
-        .settings(cache=use_cache)
+        .settings(cache=use_cache, prefetch=prefetch)
         .map(signal=new_signal)
+        .save()
     )
 
     expected = {
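
Prefetching applies to `.gen()` as well as `.map()`. A hedged sketch of a generator UDF with prefetching, assuming `.gen()` accepts the same `func`/`params`/`output` arguments it forwards to `_udf_to_obj` above and that a generator UDF yields one value per output row (the URI and names are illustrative):

    from datachain.lib.dc import DataChain
    from datachain.lib.file import File

    def line_lengths(file: File):
        # Yields one output row per line; with prefetch > 0, the next files are
        # already downloading while this one is being read.
        with file.open() as f:
            for line in f:
                yield len(line)

    chain = (
        DataChain.from_storage("s3://example-bucket/", anon=True)  # illustrative URI
        .settings(cache=True, prefetch=2)
        .gen(line_lengths, output={"length": int})
        .save("line-lengths")
    )
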
"gs://dvcx-datalakes/dogs-and-cats/", anon=True, ).filter(C("file.path").glob("*cat*")).settings(parallel=1).map( - name_len, params=["file.path"], output={"name_len": int} + name_len, params=["file"], output={"name_len": int} ).save("name_len") From b54545edf9157d5220c8eaffe5fa0b92360b1174 Mon Sep 17 00:00:00 2001 From: skshetry <18718008+skshetry@users.noreply.github.com> Date: Wed, 20 Nov 2024 08:17:50 +0545 Subject: [PATCH 4/4] asyncmapper: shutdown producer on generator close (#597) --- src/datachain/asyn.py | 21 +++++++++++++++++++++ 1 file changed, 21 insertions(+) diff --git a/src/datachain/asyn.py b/src/datachain/asyn.py index 7c94190fa..1b87afc41 100644 --- a/src/datachain/asyn.py +++ b/src/datachain/asyn.py @@ -1,4 +1,5 @@ import asyncio +import threading from collections.abc import ( AsyncIterable, Awaitable, @@ -54,6 +55,7 @@ def __init__( self.loop = get_loop() if loop is None else loop self.pool = ThreadPoolExecutor(workers) self._tasks: set[asyncio.Task] = set() + self._shutdown_producer = threading.Event() def start_task(self, coro: Coroutine) -> asyncio.Task: task = self.loop.create_task(coro) @@ -63,12 +65,30 @@ def start_task(self, coro: Coroutine) -> asyncio.Task: def _produce(self) -> None: for item in self.iterable: + if self._shutdown_producer.is_set(): + return fut = asyncio.run_coroutine_threadsafe(self.work_queue.put(item), self.loop) fut.result() # wait until the item is in the queue async def produce(self) -> None: await self.to_thread(self._produce) + def shutdown_producer(self) -> None: + """ + Signal the producer to stop and drain any remaining items from the work_queue. + + This method sets an internal event, `_shutdown_producer`, which tells the + producer that it should stop adding items to the queue. To ensure that the + producer notices this signal promptly, we also attempt to drain any items + currently in the queue, clearing it so that the event can be checked without + delay. + """ + self._shutdown_producer.set() + q = self.work_queue + while not q.empty(): + q.get_nowait() + q.task_done() + async def worker(self) -> None: while (item := await self.work_queue.get()) is not None: try: @@ -156,6 +176,7 @@ def iterate(self, timeout=None) -> Generator[ResultT, None, None]: if exc := async_run.exception(): raise exc finally: + self.shutdown_producer() if not async_run.done(): async_run.cancel()