Skip to content

Commit

Permalink
fix(cache, client): support versioned files
Browse files Browse the repository at this point in the history
  • Loading branch information
shcheklein committed Dec 27, 2024
1 parent f29dac7 commit f480f38
Show file tree
Hide file tree
Showing 6 changed files with 86 additions and 27 deletions.
6 changes: 4 additions & 2 deletions src/datachain/cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,14 +61,16 @@ async def download(
tmp_info = odb_fs.join(self.odb.tmp_dir, tmp_fname()) # type: ignore[arg-type]
size = file.size
if size < 0:
size = await client.get_size(from_path)
size = await client.get_size(from_path, version_id=file.version)

Check warning on line 64 in src/datachain/cache.py

View check run for this annotation

Codecov / codecov/patch

src/datachain/cache.py#L64

Added line #L64 was not covered by tests
cb = callback or TqdmCallback(
tqdm_kwargs={"desc": odb_fs.name(from_path), "bytes": True},
tqdm_cls=Tqdm,
size=size,
)
try:
await client.get_file(from_path, tmp_info, callback=cb)
await client.get_file(
from_path, tmp_info, callback=cb, version_id=file.version
)
finally:
if not callback:
cb.close()
Expand Down
46 changes: 36 additions & 10 deletions src/datachain/client/fsspec.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,6 +137,18 @@ def create_fs(cls, **kwargs) -> "AbstractFileSystem":
fs.invalidate_cache()
return fs

@classmethod
def split_version(cls, path: str) -> tuple[str, Optional[str]]:
return path, None

Check warning on line 142 in src/datachain/client/fsspec.py

View check run for this annotation

Codecov / codecov/patch

src/datachain/client/fsspec.py#L142

Added line #L142 was not covered by tests

@classmethod
def join_version(cls, path: str, version_id: Optional[str]) -> str:
return path

Check warning on line 146 in src/datachain/client/fsspec.py

View check run for this annotation

Codecov / codecov/patch

src/datachain/client/fsspec.py#L146

Added line #L146 was not covered by tests

@classmethod
def version_path(cls, path: str, version_id: Optional[str]) -> str:
return path

@classmethod
def from_name(
cls,
Expand Down Expand Up @@ -201,18 +213,30 @@ def url(self, path: str, expires: int = 3600, **kwargs) -> str:
return self.fs.sign(self.get_full_path(path), expiration=expires, **kwargs)

async def get_current_etag(self, file: "File") -> str:
info = await self.fs._info(self.get_full_path(file.path))
kwargs = {}
if self.fs.version_aware:
kwargs["version_id"] = file.version
info = await self.fs._info(
self.get_full_path(file.path, file.version), **kwargs
)
return self.info_to_file(info, "").etag

def get_file_info(self, path: str) -> "File":
info = self.fs.info(self.get_full_path(path))
def get_file_info(self, path: str, version_id: Optional[str] = None) -> "File":
info = self.fs.info(self.get_full_path(path, version_id), version_id=version_id)
return self.info_to_file(info, path)

async def get_size(self, path: str) -> int:
return await self.fs._size(path)
async def get_size(self, path: str, version_id: Optional[str] = None) -> int:
return await self.fs._size(

Check warning on line 229 in src/datachain/client/fsspec.py

View check run for this annotation

Codecov / codecov/patch

src/datachain/client/fsspec.py#L229

Added line #L229 was not covered by tests
self.version_path(path, version_id), version_id=version_id
)

async def get_file(self, lpath, rpath, callback):
return await self.fs._get_file(lpath, rpath, callback=callback)
async def get_file(self, lpath, rpath, callback, version_id: Optional[str] = None):
return await self.fs._get_file(
self.version_path(lpath, version_id),
rpath,
callback=callback,
version_id=version_id,
)

async def scandir(
self, start_prefix: str, method: str = "default"
Expand Down Expand Up @@ -319,8 +343,8 @@ async def ls_dir(self, path):
def rel_path(self, path: str) -> str:
return self.fs.split_path(path)[1]

def get_full_path(self, rel_path: str) -> str:
return f"{self.PREFIX}{self.name}/{rel_path}"
def get_full_path(self, rel_path: str, version_id: Optional[str] = None) -> str:
return self.version_path(f"{self.PREFIX}{self.name}/{rel_path}", version_id)

@abstractmethod
def info_to_file(self, v: dict[str, Any], parent: str) -> "File": ...
Expand Down Expand Up @@ -366,7 +390,9 @@ def open_object(
if use_cache and (cache_path := self.cache.get_path(file)):
return open(cache_path, mode="rb")
assert not file.location
return FileWrapper(self.fs.open(self.get_full_path(file.path)), cb) # type: ignore[return-value]
return FileWrapper(
self.fs.open(self.get_full_path(file.path, file.version)), cb
) # type: ignore[return-value]

def download(self, file: "File", *, callback: Callback = DEFAULT_CALLBACK) -> None:
sync(get_loop(), functools.partial(self._download, file, callback=callback))
Expand Down
24 changes: 24 additions & 0 deletions src/datachain/client/gcs.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from collections.abc import Iterable
from datetime import datetime
from typing import Any, Optional, cast
from urllib.parse import urlsplit

from dateutil.parser import isoparse
from gcsfs import GCSFileSystem
Expand Down Expand Up @@ -130,3 +131,26 @@ def info_to_file(self, v: dict[str, Any], path: str) -> File:
last_modified=self.parse_timestamp(v["updated"]),
size=v.get("size", ""),
)

@classmethod
def split_version(cls, path: str) -> tuple[str, Optional[str]]:
parts = list(urlsplit(path))
scheme = parts[0]
parts = GCSFileSystem._split_path( # pylint: disable=protected-access
path, version_aware=True
)
bucket, key, generation = parts
scheme = f"{scheme}://" if scheme else ""
return f"{scheme}{bucket}/{key}", generation

@classmethod
def join_version(cls, path: str, version_id: Optional[str]) -> str:
path, path_version = cls.split_version(path)
if path_version:
raise ValueError("path already includes an object generation")

Check warning on line 150 in src/datachain/client/gcs.py

View check run for this annotation

Codecov / codecov/patch

src/datachain/client/gcs.py#L150

Added line #L150 was not covered by tests
return f"{path}#{version_id}" if version_id else path

@classmethod
def version_path(cls, path: str, version_id: Optional[str]) -> str:
path, _ = cls.split_version(path)
return cls.join_version(path, version_id)
8 changes: 4 additions & 4 deletions src/datachain/client/local.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import posixpath
from datetime import datetime, timezone
from pathlib import Path
from typing import TYPE_CHECKING, Any
from typing import TYPE_CHECKING, Any, Optional
from urllib.parse import urlparse

from fsspec.implementations.local import LocalFileSystem
Expand Down Expand Up @@ -105,10 +105,10 @@ async def get_current_etag(self, file: "File") -> str:
info = self.fs.info(self.get_full_path(file.path))
return self.info_to_file(info, "").etag

async def get_size(self, path: str) -> int:
async def get_size(self, path: str, version_id: Optional[str] = None) -> int:
return self.fs.size(path)

async def get_file(self, lpath, rpath, callback):
async def get_file(self, lpath, rpath, callback, version_id: Optional[str] = None):
return self.fs.get_file(lpath, rpath, callback=callback)

async def ls_dir(self, path):
Expand All @@ -117,7 +117,7 @@ async def ls_dir(self, path):
def rel_path(self, path):
return posixpath.relpath(path, self.name)

def get_full_path(self, rel_path):
def get_full_path(self, rel_path, version_id: Optional[str] = None):
full_path = Path(self.name, rel_path).as_posix()
if rel_path.endswith("/") or not rel_path:
full_path += "/"
Expand Down
28 changes: 17 additions & 11 deletions tests/func/test_catalog.py
Original file line number Diff line number Diff line change
Expand Up @@ -461,18 +461,18 @@ def test_cp_single_file(cloud_test_catalog, no_glob):
assert tree_from_path(dest) == {"local_dog": "woof"}


@pytest.mark.parametrize("tree", [{"foo": "original"}], indirect=True)
@pytest.mark.parametrize("tree", [{"bar-file": "original"}], indirect=True)
def test_cp_file_storage_mutation(cloud_test_catalog):
working_dir = cloud_test_catalog.working_dir
catalog = cloud_test_catalog.catalog
src_path = f"{cloud_test_catalog.src_uri}/foo"
src_path = f"{cloud_test_catalog.src_uri}/bar-file"

dest = working_dir / "data1"
dest.mkdir()
catalog.cp([src_path], str(dest / "local"), no_edatachain_file=True)
assert tree_from_path(dest) == {"local": "original"}

(cloud_test_catalog.src / "foo").write_text("modified")
(cloud_test_catalog.src / "bar-file").write_text("modified")
dest = working_dir / "data2"
dest.mkdir()
catalog.cp([src_path], str(dest / "local"), no_edatachain_file=True)
Expand All @@ -493,22 +493,22 @@ def test_cp_file_storage_mutation(cloud_test_catalog):
assert tree_from_path(dest) == {"local": "modified"}


@pytest.mark.parametrize("tree", [{"foo-dir": "original"}], indirect=True)
def test_cp_dir_storage_mutation(cloud_test_catalog):
@pytest.mark.parametrize("tree", [{"foo-file": "original"}], indirect=True)
def test_cp_dir_storage_mutation(cloud_test_catalog, version_aware):
working_dir = cloud_test_catalog.working_dir
catalog = cloud_test_catalog.catalog
src_path = f"{cloud_test_catalog.src_uri}/"

dest = working_dir / "data1"
dest.mkdir()
catalog.cp([src_path], str(dest / "local"), no_edatachain_file=True, recursive=True)
assert tree_from_path(dest) == {"local": {"foo-dir": "original"}}
assert tree_from_path(dest) == {"local": {"foo-file": "original"}}

(cloud_test_catalog.src / "foo-dir").write_text("modified")
(cloud_test_catalog.src / "foo-file").write_text("modified")
dest = working_dir / "data2"
dest.mkdir()
catalog.cp([src_path], str(dest / "local"), no_edatachain_file=True, recursive=True)
assert tree_from_path(dest) == {"local": {"foo-dir": "original"}}
assert tree_from_path(dest) == {"local": {"foo-file": "original"}}

# For a dir we access files through listing
# so it finds a etag for the origin file, but it's now not in cache + it
Expand All @@ -517,17 +517,23 @@ def test_cp_dir_storage_mutation(cloud_test_catalog):
catalog.cache.clear()
dest = working_dir / "data3"
dest.mkdir()
with pytest.raises(FileNotFoundError):
if version_aware:
catalog.cp(
[src_path], str(dest / "local"), no_edatachain_file=True, recursive=True
)
assert tree_from_path(dest) == {"local": {}}
assert tree_from_path(dest) == {"local": {"foo-file": "original"}}
else:
with pytest.raises(FileNotFoundError):
catalog.cp(
[src_path], str(dest / "local"), no_edatachain_file=True, recursive=True
)
assert tree_from_path(dest) == {"local": {}}

catalog.index([src_path], update=True)
dest = working_dir / "data4"
dest.mkdir()
catalog.cp([src_path], str(dest / "local"), no_edatachain_file=True, recursive=True)
assert tree_from_path(dest) == {"local": {"foo-dir": "modified"}}
assert tree_from_path(dest) == {"local": {"foo-file": "modified"}}


def test_cp_edatachain_file_options(cloud_test_catalog):
Expand Down
1 change: 1 addition & 0 deletions tests/func/test_pull.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@ def _adapt_row(row):
adapted["sys__rand"] = 1
adapted["file__location"] = ""
adapted["file__source"] = src_uri
adapted["file__version"] = ""
return adapted

dog_entries = [_adapt_row(e) for e in dog_entries]
Expand Down

0 comments on commit f480f38

Please sign in to comment.