Implement prefetching in .gen() and .map()
rlamy committed Oct 18, 2024
1 parent 52de616 commit d1d1457
Showing 7 changed files with 101 additions and 50 deletions.
5 changes: 4 additions & 1 deletion src/datachain/data_storage/warehouse.py
@@ -227,7 +227,10 @@ def dataset_select_paginated(
if limit < page_size:
paginated_query = paginated_query.limit(None).limit(limit)

results = self.dataset_rows_select(paginated_query.offset(offset))
# Ensure we're using a thread-local connection
with self.clone() as wh:
# Cursor results are not thread-safe, so we convert them to a list
results = list(wh.dataset_rows_select(paginated_query.offset(offset)))

processed = False
for row in results:
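A note on why the page is both cloned and materialized: once prefetching is enabled, this paginated generator may be advanced from a worker thread rather than the thread that opened the warehouse connection, and DB cursors generally must not be shared across threads. A minimal, library-agnostic sketch of the same pattern (sqlite3 and the table name are illustrative, not DataChain's API):

import sqlite3
from collections.abc import Iterator

def paginated_rows(db_path: str, page_size: int = 100) -> Iterator[tuple]:
    """Yield rows page by page, materializing each page before it crosses threads."""
    offset = 0
    while True:
        # A fresh connection per page keeps the cursor local to whichever
        # thread happens to be driving the generator.
        conn = sqlite3.connect(db_path)
        try:
            page = list(
                conn.execute(
                    "SELECT * FROM items LIMIT ? OFFSET ?", (page_size, offset)
                )
            )
        finally:
            conn.close()
        if not page:
            return
        yield from page
        offset += page_size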
7 changes: 6 additions & 1 deletion src/datachain/lib/dc.py
@@ -325,6 +325,7 @@ def settings(
parallel=None,
workers=None,
min_task_size=None,
prefetch: Optional[int] = None,
sys: Optional[bool] = None,
) -> "Self":
"""Change settings for chain.
@@ -351,7 +352,7 @@ def settings(
if sys is None:
sys = self._sys
settings = copy.copy(self._settings)
settings.add(Settings(cache, parallel, workers, min_task_size))
settings.add(Settings(cache, parallel, workers, min_task_size, prefetch))
return self._evolve(settings=settings, _sys=sys)

def reset_settings(self, settings: Optional[Settings] = None) -> "Self":
@@ -801,6 +802,8 @@ def map(
```
"""
udf_obj = self._udf_to_obj(Mapper, func, params, output, signal_map)
if (prefetch := self._settings.prefetch) is not None:
udf_obj.prefetch = prefetch

return self._evolve(
query=self._query.add_signals(
@@ -838,6 +841,8 @@ def gen(
```
"""
udf_obj = self._udf_to_obj(Generator, func, params, output, signal_map)
if (prefetch := self._settings.prefetch) is not None:
udf_obj.prefetch = prefetch
return self._evolve(
query=self._query.generate(
udf_obj.to_udf_wrapper(),
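With the new setting, callers opt into (or out of) prefetching per chain; prefetch=0 disables it, and any positive value bounds how many rows are downloaded ahead of the UDF. A usage sketch (the bucket URI, signal name, and dataset name are made up for illustration):

from datachain.lib.dc import DataChain
from datachain.lib.file import File

def describe(file: File) -> str:
    # By the time the mapper sees the file, a background worker has usually
    # already pulled it into the local cache.
    return f"{file.name}: {file.size} bytes"

chain = (
    DataChain.from_storage("s3://my-bucket/images/")
    .settings(cache=True, prefetch=4)  # keep up to 4 downloads in flight
    .map(desc=describe)
    .save("described_images")
)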
5 changes: 5 additions & 0 deletions src/datachain/lib/file.py
@@ -271,6 +271,11 @@ def ensure_cached(self) -> None:
client = self._catalog.get_client(self.source)
client.download(self, callback=self._download_cb)

async def _prefetch(self) -> None:
if self._caching_enabled:
client = self._catalog.get_client(self.source)
await client._download(self, callback=self._download_cb)

def get_local_path(self) -> Optional[str]:
"""Return path to a file in a local cache.
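_prefetch is the awaitable counterpart of ensure_cached: it is a no-op unless caching is enabled, and it goes through the client's async download path so that several files can be fetched concurrently. A rough sketch of driving it directly (illustrative only; in practice the UDF machinery calls it via _prefetch_input):

import asyncio

async def warm_cache(files) -> None:
    # Start all downloads at once; each _prefetch() only does work for
    # files whose caching is enabled.
    await asyncio.gather(*(f._prefetch() for f in files))

# asyncio.run(warm_cache(list_of_files))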
12 changes: 11 additions & 1 deletion src/datachain/lib/settings.py
@@ -7,11 +7,19 @@ def __init__(self, msg):


class Settings:
def __init__(self, cache=None, parallel=None, workers=None, min_task_size=None):
def __init__(
self,
cache=None,
parallel=None,
workers=None,
min_task_size=None,
prefetch=None,
):
self._cache = cache
self.parallel = parallel
self._workers = workers
self.min_task_size = min_task_size
self.prefetch = prefetch

if not isinstance(cache, bool) and cache is not None:
raise SettingsError(
@@ -66,3 +74,5 @@ def add(self, settings: "Settings"):
self.parallel = settings.parallel or self.parallel
self._workers = settings._workers or self._workers
self.min_task_size = settings.min_task_size or self.min_task_size
if settings.prefetch is not None:
self.prefetch = settings.prefetch
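
The merge deliberately checks "is not None" rather than truthiness: prefetch=0 is a meaningful value (it turns prefetching off) and must survive the merge. A small sketch of the difference, assuming the constructor shown above:

from datachain.lib.settings import Settings

base = Settings(cache=True, prefetch=4)
override = Settings(prefetch=0)  # explicitly disable prefetching

base.add(override)
assert base.prefetch == 0  # kept, because the merge tests "is not None"
# Merging with `self.prefetch = settings.prefetch or self.prefetch`
# would silently have kept 4 instead.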
63 changes: 45 additions & 18 deletions src/datachain/lib/udf.py
@@ -1,3 +1,4 @@
import contextlib
import sys
import traceback
from collections.abc import Iterable, Iterator, Mapping, Sequence
@@ -7,6 +8,7 @@
from fsspec.callbacks import DEFAULT_CALLBACK, Callback
from pydantic import BaseModel

from datachain.asyn import AsyncMapper
from datachain.dataset import RowDict
from datachain.lib.convert.flatten import flatten
from datachain.lib.data_model import DataValue
@@ -22,6 +24,8 @@
)

if TYPE_CHECKING:
from collections import abc

from typing_extensions import Self

from datachain.catalog import Catalog
@@ -276,9 +280,18 @@ def process_safe(self, obj_rows):
return result_objs


async def _prefetch_input(row):
for obj in row:
if isinstance(obj, File):
await obj._prefetch()
return row


class Mapper(UDFBase):
"""Inherit from this class to pass to `DataChain.map()`."""

prefetch: int = 2

def run(
self,
udf_fields: "Sequence[str]",
@@ -290,16 +303,22 @@ def run(
) -> Iterator[Iterable[UDFResult]]:
self.catalog = catalog
self.setup()

for row in udf_inputs:
id_, *udf_args = self._prepare_row_and_id(
row, udf_fields, cache, download_cb
)
result_objs = self.process_safe(udf_args)
udf_output = self._flatten_row(result_objs)
output = [{"sys__id": id_} | dict(zip(self.signal_names, udf_output))]
processed_cb.relative_update(1)
yield output
prepared_inputs: abc.Generator[Sequence[Any], None, None] = (
self._prepare_row_and_id(row, udf_fields, cache, download_cb)
for row in udf_inputs
)
if self.prefetch > 0:
prepared_inputs = AsyncMapper(
_prefetch_input, prepared_inputs, workers=self.prefetch
).iterate()

with contextlib.closing(prepared_inputs):
for id_, *udf_args in prepared_inputs:
result_objs = self.process_safe(udf_args)
udf_output = self._flatten_row(result_objs)
output = [{"sys__id": id_} | dict(zip(self.signal_names, udf_output))]
processed_cb.relative_update(1)
yield output

self.teardown()

@@ -349,6 +368,7 @@ class Generator(UDFBase):
"""Inherit from this class to pass to `DataChain.gen()`."""

is_output_batched = True
prefetch: int = 2

def run(
self,
@@ -361,14 +381,21 @@ def run(
) -> Iterator[Iterable[UDFResult]]:
self.catalog = catalog
self.setup()

for row in udf_inputs:
udf_args = self._prepare_row(row, udf_fields, cache, download_cb)
result_objs = self.process_safe(udf_args)
udf_outputs = (self._flatten_row(row) for row in result_objs)
output = (dict(zip(self.signal_names, row)) for row in udf_outputs)
processed_cb.relative_update(1)
yield output
prepared_inputs: abc.Generator[Sequence[Any], None, None] = (
self._prepare_row(row, udf_fields, cache, download_cb) for row in udf_inputs
)
if self.prefetch > 0:
prepared_inputs = AsyncMapper(
_prefetch_input, prepared_inputs, workers=self.prefetch
).iterate()

with contextlib.closing(prepared_inputs):
for row in prepared_inputs:
result_objs = self.process_safe(row)
udf_outputs = (self._flatten_row(row) for row in result_objs)
output = (dict(zip(self.signal_names, row)) for row in udf_outputs)
processed_cb.relative_update(1)
yield output

self.teardown()

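The prefetching pattern in both run() methods is the same: prepare rows lazily, hand the lazy stream to AsyncMapper, which awaits _prefetch_input on up to `prefetch` rows ahead of the consumer, then iterate the already-downloaded rows synchronously. A self-contained approximation of that sliding-window idea, using a thread pool instead of datachain.asyn.AsyncMapper (names below are illustrative):

from collections import deque
from collections.abc import Callable, Iterable, Iterator
from concurrent.futures import ThreadPoolExecutor
from typing import TypeVar

T = TypeVar("T")

def prefetched(rows: Iterable[T], fetch: Callable[[T], T], workers: int = 2) -> Iterator[T]:
    """Yield fetch(row) for every row, keeping up to `workers` fetches in flight."""
    with ThreadPoolExecutor(max_workers=workers) as pool:
        pending: deque = deque()
        for row in rows:
            pending.append(pool.submit(fetch, row))
            # Once the window is full, hand the oldest finished row to the consumer.
            if len(pending) >= workers:
                yield pending.popleft().result()
        while pending:
            yield pending.popleft().result()

# e.g. rows = prefetched(prepared_inputs, download_into_cache, workers=2)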
52 changes: 25 additions & 27 deletions src/datachain/query/dataset.py
@@ -472,33 +472,31 @@ def populate_udf_table(self, udf_table: "Table", query: Select) -> None:
# Otherwise process single-threaded (faster for smaller UDFs)
warehouse = self.catalog.warehouse

with contextlib.closing(
batching(warehouse.dataset_select_paginated, query)
) as udf_inputs:
download_cb = get_download_callback()
processed_cb = get_processed_callback()
generated_cb = get_generated_callback(self.is_generator)
try:
udf_results = self.udf.run(
udf_fields,
udf_inputs,
self.catalog,
self.is_generator,
self.cache,
download_cb,
processed_cb,
)
process_udf_outputs(
warehouse,
udf_table,
udf_results,
self.udf,
cb=generated_cb,
)
finally:
download_cb.close()
processed_cb.close()
generated_cb.close()
udf_inputs = batching(warehouse.dataset_select_paginated, query)
download_cb = get_download_callback()
processed_cb = get_processed_callback()
generated_cb = get_generated_callback(self.is_generator)
try:
udf_results = self.udf.run(
udf_fields,
udf_inputs,
self.catalog,
self.is_generator,
self.cache,
download_cb,
processed_cb,
)
process_udf_outputs(
warehouse,
udf_table,
udf_results,
self.udf,
cb=generated_cb,
)
finally:
download_cb.close()
processed_cb.close()
generated_cb.close()

warehouse.insert_rows_done(udf_table)

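The contextlib.closing wrapper is dropped at this layer; the UDF's run() now wraps its prepared_inputs stream in contextlib.closing instead (see udf.py above). closing() simply guarantees that a generator's close() is called even when iteration stops early, e.g.:

import contextlib

def pages():
    try:
        yield from range(3)
    finally:
        print("cleanup runs even if the consumer stops early")

with contextlib.closing(pages()) as rows:
    for row in rows:
        break  # leaving the block still calls rows.close()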
7 changes: 5 additions & 2 deletions tests/func/test_datachain.py
@@ -208,17 +208,20 @@ def test_from_storage_dependencies(cloud_test_catalog, cloud_type):


@pytest.mark.parametrize("use_cache", [True, False])
def test_map_file(cloud_test_catalog, use_cache):
@pytest.mark.parametrize("prefetch", [0, 2])
def test_map_file(cloud_test_catalog, use_cache, prefetch):
ctc = cloud_test_catalog

def new_signal(file: File) -> str:
assert bool(file.get_local_path()) is (use_cache and prefetch > 0)
with file.open() as f:
return file.name + " -> " + f.read().decode("utf-8")

dc = (
DataChain.from_storage(ctc.src_uri, session=ctc.session)
.settings(cache=use_cache)
.settings(cache=use_cache, prefetch=prefetch)
.map(signal=new_signal)
.save()
)

expected = {
