Implement pre-fetching in map() and gen() #521

Merged: 4 commits, Nov 20, 2024
40 changes: 36 additions & 4 deletions src/datachain/asyn.py
@@ -1,5 +1,13 @@
import asyncio
- from collections.abc import AsyncIterable, Awaitable, Coroutine, Iterable, Iterator
+ import threading
+ from collections.abc import (
+ AsyncIterable,
+ Awaitable,
+ Coroutine,
+ Generator,
+ Iterable,
+ Iterator,
+ )
from concurrent.futures import ThreadPoolExecutor
from heapq import heappop, heappush
from typing import Any, Callable, Generic, Optional, TypeVar
@@ -47,16 +55,39 @@ def __init__(
self.loop = get_loop() if loop is None else loop
self.pool = ThreadPoolExecutor(workers)
self._tasks: set[asyncio.Task] = set()
self._shutdown_producer = threading.Event()

def start_task(self, coro: Coroutine) -> asyncio.Task:
task = self.loop.create_task(coro)
self._tasks.add(task)
task.add_done_callback(self._tasks.discard)
return task

- async def produce(self) -> None:
+ def _produce(self) -> None:
for item in self.iterable:
- await self.work_queue.put(item)
+ if self._shutdown_producer.is_set():
+ return
+ fut = asyncio.run_coroutine_threadsafe(self.work_queue.put(item), self.loop)
+ fut.result()  # wait until the item is in the queue

+ async def produce(self) -> None:
+ await self.to_thread(self._produce)

def shutdown_producer(self) -> None:
"""
Signal the producer to stop and drain any remaining items from the work_queue.

This method sets an internal event, `_shutdown_producer`, which tells the
producer that it should stop adding items to the queue. To ensure that the
producer notices this signal promptly, we also attempt to drain any items
currently in the queue, clearing it so that the event can be checked without
delay.
"""
self._shutdown_producer.set()
q = self.work_queue
while not q.empty():
q.get_nowait()
q.task_done()

async def worker(self) -> None:
while (item := await self.work_queue.get()) is not None:
@@ -132,7 +163,7 @@ async def _break_iteration(self) -> None:
self.result_queue.get_nowait()
await self.result_queue.put(None)

- def iterate(self, timeout=None) -> Iterable[ResultT]:
+ def iterate(self, timeout=None) -> Generator[ResultT, None, None]:
init = asyncio.run_coroutine_threadsafe(self.init(), self.loop)
init.result(timeout=1)
async_run = asyncio.run_coroutine_threadsafe(self.run(), self.loop)
@@ -145,6 +176,7 @@ def iterate(self, timeout=None) -> Iterable[ResultT]:
if exc := async_run.exception():
raise exc
finally:
self.shutdown_producer()
if not async_run.done():
async_run.cancel()

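For readers following this hunk: the producer now runs in a worker thread and feeds the event loop's queue via `asyncio.run_coroutine_threadsafe`, which is why a `threading.Event` plus a queue drain is enough to stop it promptly, and why `iterate()` is now annotated as a `Generator` (callers can wrap it in `contextlib.closing`). Below is a minimal, standard-library-only sketch of the same pattern; it is not datachain code and all names are illustrative.

```python
import asyncio
import threading


def producer(loop, queue, items, stop_event):
    # Runs in a worker thread; each put() is handed to the event loop.
    for item in items:
        if stop_event.is_set():
            return
        fut = asyncio.run_coroutine_threadsafe(queue.put(item), loop)
        fut.result()  # block this thread until the item is enqueued


async def main():
    queue: asyncio.Queue = asyncio.Queue(maxsize=2)
    stop_event = threading.Event()
    loop = asyncio.get_running_loop()
    prod = asyncio.create_task(
        asyncio.to_thread(producer, loop, queue, range(100), stop_event)
    )

    for _ in range(5):  # consume a few items, then shut the producer down early
        print(await queue.get())

    stop_event.set()
    while not queue.empty():  # drain so a blocked put() completes and the event is seen
        queue.get_nowait()
    await prod


asyncio.run(main())
```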
5 changes: 4 additions & 1 deletion src/datachain/data_storage/warehouse.py
@@ -232,7 +232,10 @@ def dataset_select_paginated(
if limit < page_size:
paginated_query = paginated_query.limit(None).limit(limit)

- results = self.dataset_rows_select(paginated_query.offset(offset))
+ # Ensure we're using a thread-local connection
+ with self.clone() as wh:
Review comment (on lines +235 to +236): I guess this clone() should be outside the loop? Otherwise, we'll not be reusing the connection.
+ # Cursor results are not thread-safe, so we convert them to a list
+ results = list(wh.dataset_rows_select(paginated_query.offset(offset)))

processed = False
for row in results:
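To illustrate the "Cursor results are not thread-safe" comment: the page is materialized into a plain list in the thread that owns the database connection before any other thread (such as the prefetcher added by this PR) gets to consume it. A tiny stand-alone illustration with sqlite3, not the warehouse API:

```python
import sqlite3
import threading

conn = sqlite3.connect(":memory:")  # check_same_thread=True by default
conn.execute("CREATE TABLE t (x)")
conn.executemany("INSERT INTO t VALUES (?)", [(i,) for i in range(10)])

cursor = conn.execute("SELECT x FROM t")
rows = list(cursor)  # materialize in the owning thread

def consume(page):
    # Safe: `page` is plain Python objects, no live cursor or connection involved.
    print(sum(x for (x,) in page))

t = threading.Thread(target=consume, args=(rows,))
t.start()
t.join()
```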
7 changes: 6 additions & 1 deletion src/datachain/lib/dc.py
@@ -334,6 +334,7 @@
parallel=None,
workers=None,
min_task_size=None,
prefetch: Optional[int] = None,
Review comment (@shcheklein, Oct 19, 2024): q: why int? let's update the docs here (do we have some CI to detect these discrepancies btw (missing docs)? cc @skshetry)
sys: Optional[bool] = None,
) -> "Self":
"""Change settings for chain.
@@ -360,7 +361,7 @@
if sys is None:
sys = self._sys
settings = copy.copy(self._settings)
- settings.add(Settings(cache, parallel, workers, min_task_size))
+ settings.add(Settings(cache, parallel, workers, min_task_size, prefetch))
return self._evolve(settings=settings, _sys=sys)

def reset_settings(self, settings: Optional[Settings] = None) -> "Self":
@@ -882,6 +883,8 @@
```
"""
udf_obj = self._udf_to_obj(Mapper, func, params, output, signal_map)
if (prefetch := self._settings.prefetch) is not None:
udf_obj.prefetch = prefetch

return self._evolve(
query=self._query.add_signals(
Expand Down Expand Up @@ -919,6 +922,8 @@
```
"""
udf_obj = self._udf_to_obj(Generator, func, params, output, signal_map)
if (prefetch := self._settings.prefetch) is not None:
udf_obj.prefetch = prefetch

Codecov / codecov/patch: added line src/datachain/lib/dc.py#L926 was not covered by tests.
return self._evolve(
query=self._query.generate(
udf_obj.to_udf_wrapper(),
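Regarding the documentation question above, here is a hypothetical end-to-end use of the new setting, modeled on tests/scripts/name_len_slow.py; the bucket URI and dataset name are placeholders. With `prefetch=4`, up to four input files are downloaded ahead of the UDF while it processes earlier rows.

```python
from datachain.lib.dc import DataChain  # import path per this repo's src layout


def name_len(file) -> int:
    return len(file.path)


(
    DataChain.from_storage("gs://example-bucket/data/", anon=True)
    .settings(cache=True, prefetch=4)  # prefetch=0 would disable prefetching
    .map(name_len, params=["file"], output={"name_len": int})
    .save("name_len")
)
```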
5 changes: 5 additions & 0 deletions src/datachain/lib/file.py
@@ -268,6 +268,11 @@ def ensure_cached(self) -> None:
client = self._catalog.get_client(self.source)
client.download(self, callback=self._download_cb)

async def _prefetch(self) -> None:
if self._caching_enabled:
client = self._catalog.get_client(self.source)
await client._download(self, callback=self._download_cb)

def get_local_path(self) -> Optional[str]:
"""Return path to a file in a local cache.

12 changes: 11 additions & 1 deletion src/datachain/lib/settings.py
@@ -7,11 +7,19 @@ def __init__(self, msg):


class Settings:
- def __init__(self, cache=None, parallel=None, workers=None, min_task_size=None):
+ def __init__(
+ self,
+ cache=None,
+ parallel=None,
+ workers=None,
+ min_task_size=None,
+ prefetch=None,
+ ):
self._cache = cache
self.parallel = parallel
self._workers = workers
self.min_task_size = min_task_size
self.prefetch = prefetch

if not isinstance(cache, bool) and cache is not None:
raise SettingsError(
@@ -66,3 +74,5 @@ def add(self, settings: "Settings"):
self.parallel = settings.parallel or self.parallel
self._workers = settings._workers or self._workers
self.min_task_size = settings.min_task_size or self.min_task_size
if settings.prefetch is not None:
self.prefetch = settings.prefetch
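Note that `prefetch` is merged with an explicit `is not None` check rather than the `or` pattern used for the other fields, so an explicit `prefetch=0` (disable prefetching) still wins over a previously set value. A quick sketch of that behaviour, assuming the import path matches this repo's layout:

```python
from datachain.lib.settings import Settings

base = Settings(prefetch=2)
base.add(Settings(prefetch=0))  # explicit 0 must survive the merge
assert base.prefetch == 0       # an `or`-based merge would have kept 2
```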
63 changes: 45 additions & 18 deletions src/datachain/lib/udf.py
@@ -1,3 +1,4 @@
import contextlib
import sys
import traceback
from collections.abc import Iterable, Iterator, Mapping, Sequence
@@ -7,6 +8,7 @@
from fsspec.callbacks import DEFAULT_CALLBACK, Callback
from pydantic import BaseModel

from datachain.asyn import AsyncMapper
from datachain.dataset import RowDict
from datachain.lib.convert.flatten import flatten
from datachain.lib.data_model import DataValue
Expand All @@ -21,6 +23,8 @@
)

if TYPE_CHECKING:
from collections import abc

from typing_extensions import Self

from datachain.catalog import Catalog
@@ -276,9 +280,18 @@ def process_safe(self, obj_rows):
return result_objs


async def _prefetch_input(row):
for obj in row:
if isinstance(obj, File):
await obj._prefetch()
return row


class Mapper(UDFBase):
"""Inherit from this class to pass to `DataChain.map()`."""

prefetch: int = 2

def run(
self,
udf_fields: "Sequence[str]",
@@ -290,16 +303,22 @@ def run(
) -> Iterator[Iterable[UDFResult]]:
self.catalog = catalog
self.setup()

- for row in udf_inputs:
- id_, *udf_args = self._prepare_row_and_id(
- row, udf_fields, cache, download_cb
- )
- result_objs = self.process_safe(udf_args)
- udf_output = self._flatten_row(result_objs)
- output = [{"sys__id": id_} | dict(zip(self.signal_names, udf_output))]
- processed_cb.relative_update(1)
- yield output
+ prepared_inputs: abc.Generator[Sequence[Any], None, None] = (
+ self._prepare_row_and_id(row, udf_fields, cache, download_cb)
+ for row in udf_inputs
+ )
+ if self.prefetch > 0:
+ prepared_inputs = AsyncMapper(
+ _prefetch_input, prepared_inputs, workers=self.prefetch
+ ).iterate()
+
+ with contextlib.closing(prepared_inputs):
+ for id_, *udf_args in prepared_inputs:
+ result_objs = self.process_safe(udf_args)
+ udf_output = self._flatten_row(result_objs)
+ output = [{"sys__id": id_} | dict(zip(self.signal_names, udf_output))]
+ processed_cb.relative_update(1)
+ yield output

self.teardown()

@@ -349,6 +368,7 @@ class Generator(UDFBase):
"""Inherit from this class to pass to `DataChain.gen()`."""

is_output_batched = True
prefetch: int = 2

def run(
self,
@@ -361,14 +381,21 @@ def run(
) -> Iterator[Iterable[UDFResult]]:
self.catalog = catalog
self.setup()

- for row in udf_inputs:
- udf_args = self._prepare_row(row, udf_fields, cache, download_cb)
- result_objs = self.process_safe(udf_args)
- udf_outputs = (self._flatten_row(row) for row in result_objs)
- output = (dict(zip(self.signal_names, row)) for row in udf_outputs)
- processed_cb.relative_update(1)
- yield output
+ prepared_inputs: abc.Generator[Sequence[Any], None, None] = (
+ self._prepare_row(row, udf_fields, cache, download_cb) for row in udf_inputs
+ )
+ if self.prefetch > 0:
+ prepared_inputs = AsyncMapper(
+ _prefetch_input, prepared_inputs, workers=self.prefetch
+ ).iterate()
+
+ with contextlib.closing(prepared_inputs):
+ for row in prepared_inputs:
+ result_objs = self.process_safe(row)
+ udf_outputs = (self._flatten_row(row) for row in result_objs)
+ output = (dict(zip(self.signal_names, row)) for row in udf_outputs)
+ processed_cb.relative_update(1)
+ yield output

self.teardown()

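This hunk is the heart of the PR: prepared rows flow through an `AsyncMapper` that awaits `File._prefetch()` for up to `prefetch` items ahead of the consuming loop, and `contextlib.closing` guarantees the mapper's generator is closed even if the UDF raises. A rough stand-alone sketch of the same pipeline shape follows; only `AsyncMapper` and its `iterate()` come from datachain, the async function and row generator below are placeholders.

```python
import asyncio

from datachain.asyn import AsyncMapper


async def prefetch_input(row):
    await asyncio.sleep(0.01)  # stand-in for File._prefetch() / the actual download
    return row


def prepared_rows():
    for i in range(10):
        yield {"sys__id": i}


prefetch = 2
rows = prepared_rows()
if prefetch > 0:
    # Up to `prefetch` rows are prefetched concurrently ahead of the consumer.
    rows = AsyncMapper(prefetch_input, rows, workers=prefetch).iterate()

for row in rows:
    # The real code runs the user UDF here; downloads overlap with this work.
    # Note: AsyncMapper does not guarantee output order.
    print(row["sys__id"])
```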
52 changes: 25 additions & 27 deletions src/datachain/query/dataset.py
@@ -473,33 +473,31 @@ def populate_udf_table(self, udf_table: "Table", query: Select) -> None:
# Otherwise process single-threaded (faster for smaller UDFs)
warehouse = self.catalog.warehouse

- with contextlib.closing(
- batching(warehouse.dataset_select_paginated, query)
- ) as udf_inputs:
- download_cb = get_download_callback()
- processed_cb = get_processed_callback()
- generated_cb = get_generated_callback(self.is_generator)
- try:
- udf_results = self.udf.run(
- udf_fields,
- udf_inputs,
- self.catalog,
- self.is_generator,
- self.cache,
- download_cb,
- processed_cb,
- )
- process_udf_outputs(
- warehouse,
- udf_table,
- udf_results,
- self.udf,
- cb=generated_cb,
- )
- finally:
- download_cb.close()
- processed_cb.close()
- generated_cb.close()
+ udf_inputs = batching(warehouse.dataset_select_paginated, query)
+ download_cb = get_download_callback()
+ processed_cb = get_processed_callback()
+ generated_cb = get_generated_callback(self.is_generator)
+ try:
+ udf_results = self.udf.run(
+ udf_fields,
+ udf_inputs,
+ self.catalog,
+ self.is_generator,
+ self.cache,
+ download_cb,
+ processed_cb,
+ )
+ process_udf_outputs(
+ warehouse,
+ udf_table,
+ udf_results,
+ self.udf,
+ cb=generated_cb,
+ )
+ finally:
+ download_cb.close()
+ processed_cb.close()
+ generated_cb.close()

warehouse.insert_rows_done(udf_table)

7 changes: 5 additions & 2 deletions tests/func/test_datachain.py
@@ -213,17 +213,20 @@ def test_from_storage_dependencies(cloud_test_catalog, cloud_type):


@pytest.mark.parametrize("use_cache", [True, False])
- def test_map_file(cloud_test_catalog, use_cache):
+ @pytest.mark.parametrize("prefetch", [0, 2])
+ def test_map_file(cloud_test_catalog, use_cache, prefetch):
ctc = cloud_test_catalog

def new_signal(file: File) -> str:
assert bool(file.get_local_path()) is (use_cache and prefetch > 0)
with file.open() as f:
return file.name + " -> " + f.read().decode("utf-8")

dc = (
DataChain.from_storage(ctc.src_uri, session=ctc.session)
- .settings(cache=use_cache)
+ .settings(cache=use_cache, prefetch=prefetch)
.map(signal=new_signal)
.save()
)

expected = {
2 changes: 1 addition & 1 deletion tests/scripts/name_len_slow.py
@@ -36,5 +36,5 @@ def name_len(file):
"gs://dvcx-datalakes/dogs-and-cats/",
anon=True,
).filter(C("file.path").glob("*cat*")).settings(parallel=1).map(
name_len, params=["file.path"], output={"name_len": int}
name_len, params=["file"], output={"name_len": int}
).save("name_len")