Simplify datachain.lib.listing by reusing Client.scandir() #376

Merged · 1 commit · Aug 30, 2024
99 changes: 6 additions & 93 deletions src/datachain/lib/listing.py
@@ -1,103 +1,14 @@
-import asyncio
-from collections.abc import AsyncIterator, Iterator, Sequence
-from typing import Callable, Optional
+from collections.abc import Iterator
+from typing import Callable
 
-from botocore.exceptions import ClientError
 from fsspec.asyn import get_loop
 
 from datachain.asyn import iter_over_async
 from datachain.client import Client
-from datachain.error import ClientError as DataChainClientError
 from datachain.lib.file import File
 
-ResultQueue = asyncio.Queue[Optional[Sequence[File]]]
-
-DELIMITER = "/"  # Path delimiter
-FETCH_WORKERS = 100
-
-
-async def _fetch_dir(client, prefix, result_queue) -> set[str]:
-    path = f"{client.name}/{prefix}"
-    infos = await client.ls_dir(path)
-    files = []
-    subdirs = set()
-    for info in infos:
-        full_path = info["name"]
-        subprefix = client.rel_path(full_path)
-        if prefix.strip(DELIMITER) == subprefix.strip(DELIMITER):
-            continue
-        if info["type"] == "directory":
-            subdirs.add(subprefix)
-        else:
-            files.append(client.info_to_file(info, subprefix))
-    if files:
-        await result_queue.put(files)
-    return subdirs
-
-
-async def _fetch(
-    client, start_prefix: str, result_queue: ResultQueue, fetch_workers
-) -> None:
-    loop = get_loop()
-
-    queue: asyncio.Queue[str] = asyncio.Queue()
-    queue.put_nowait(start_prefix)
-
-    async def process(queue) -> None:
-        while True:
-            prefix = await queue.get()
-            try:
-                subdirs = await _fetch_dir(client, prefix, result_queue)
-                for subdir in subdirs:
-                    queue.put_nowait(subdir)
-            except Exception:
-                while not queue.empty():
-                    queue.get_nowait()
-                    queue.task_done()
-                raise
-            finally:
-                queue.task_done()
-
-    try:
-        workers: list[asyncio.Task] = [
-            loop.create_task(process(queue)) for _ in range(fetch_workers)
-        ]
-
-        # Wait for all fetch tasks to complete
-        await queue.join()
-        # Stop the workers
-        excs = []
-        for worker in workers:
-            if worker.done() and (exc := worker.exception()):
-                excs.append(exc)
-            else:
-                worker.cancel()
-        if excs:
-            raise excs[0]
-    except ClientError as exc:
-        raise DataChainClientError(
-            exc.response.get("Error", {}).get("Message") or exc,
-            exc.response.get("Error", {}).get("Code"),
-        ) from exc
-    finally:
-        # This ensures the progress bar is closed before any exceptions are raised
-        result_queue.put_nowait(None)
-
-
-async def _scandir(client, prefix, fetch_workers) -> AsyncIterator:
-    """Recursively goes through dir tree and yields files"""
-    result_queue: ResultQueue = asyncio.Queue()
-    loop = get_loop()
-    main_task = loop.create_task(_fetch(client, prefix, result_queue, fetch_workers))
-    while (files := await result_queue.get()) is not None:
-        for f in files:
-            yield f
-
-    await main_task
-
-
-def list_bucket(uri: str, client_config=None, fetch_workers=FETCH_WORKERS) -> Callable:
+def list_bucket(uri: str, client_config=None) -> Callable:
     """
     Function that returns another generator function that yields File objects
     from bucket where each File represents one bucket entry.
@@ -106,6 +17,8 @@ def list_bucket(uri: str, client_config=None, fetch_workers=FETCH_WORKERS) -> Ca
     def list_func() -> Iterator[File]:
         config = client_config or {}
         client, path = Client.parse_url(uri, None, **config)  # type: ignore[arg-type]
-        yield from iter_over_async(_scandir(client, path, fetch_workers), get_loop())
+        for entries in iter_over_async(client.scandir(path), get_loop()):
+            for entry in entries:
+                yield entry.to_file(client.uri)
 
     return list_func
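Note for reviewers: the sync-to-async bridging in `list_func` is unchanged by this PR; only the async source moved from the local `_scandir()` helper to `Client.scandir()`, which yields batches of entries rather than single files. A minimal sketch of that bridging pattern, with a toy generator standing in for `scandir()` (`fake_scandir` is an illustrative assumption, not DataChain code):

```python
# Minimal sketch of the sync-over-async pattern used by list_func().
# fake_scandir() is a stand-in for Client.scandir(), which yields
# batches (sequences) of entries rather than single entries.
from fsspec.asyn import get_loop

from datachain.asyn import iter_over_async


async def fake_scandir():
    # Two illustrative "batches" of directory entries.
    yield ["dogs/1.jpg", "dogs/2.jpg"]
    yield ["cats/1.jpg"]


for batch in iter_over_async(fake_scandir(), get_loop()):
    for entry in batch:
        print(entry)
```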
13 changes: 13 additions & 0 deletions src/datachain/node.py
@@ -4,6 +4,7 @@
 import attrs
 
 from datachain.cache import UniqueId
+from datachain.lib.file import File
 from datachain.storage import StorageURI
 from datachain.utils import TIME_ZERO, time_to_str
 
@@ -189,6 +190,18 @@ def parent(self):
             return ""
         return split[0]
 
+    def to_file(self, source: str) -> File:
+        return File(
+            source=source,
+            path=self.path,
+            size=self.size,
+            version=self.version,
+            etag=self.etag,
+            is_latest=self.is_latest,
+            last_modified=self.last_modified,
+            location=self.location,
+        )
+
 
 def get_path(parent: str, name: str):
     return f"{parent}/{name}" if parent else name
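Taken together, the two changes route listing through `Client.scandir()` and convert each entry to a `File` with the new `to_file()` helper. A hedged end-to-end sketch of driving the simplified `list_bucket()`; the bucket URI and the `client_config` value are illustrative assumptions, not part of this PR:

```python
# Illustrative only: the URI and client_config below are assumptions.
from datachain.lib.listing import list_bucket

list_func = list_bucket("s3://example-bucket/", client_config={"anon": True})
for file in list_func():
    # Each yielded object is a datachain.lib.file.File built by to_file().
    print(file.path, file.size, file.etag)
```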