feature(store): list_* -> AsyncGenerators #1844

Merged: 5 commits, May 7, 2024
1 change: 1 addition & 0 deletions pyproject.toml
@@ -213,6 +213,7 @@ disallow_untyped_calls = false


[tool.pytest.ini_options]
asyncio_mode = "auto"
doctest_optionflags = [
"NORMALIZE_WHITESPACE",
"ELLIPSIS",
15 changes: 8 additions & 7 deletions src/zarr/abc/store.py
@@ -1,5 +1,6 @@
from abc import abstractmethod, ABC

from collections.abc import AsyncGenerator
from typing import List, Tuple, Optional


@@ -24,7 +25,7 @@ async def get(
@abstractmethod
async def get_partial_values(
self, key_ranges: List[Tuple[str, Tuple[int, int]]]
) -> List[bytes]:
) -> List[Optional[bytes]]:
"""Retrieve possibly partial values from given key_ranges.

Parameters
@@ -106,17 +107,17 @@ def supports_listing(self) -> bool:
...

@abstractmethod
async def list(self) -> List[str]:
def list(self) -> AsyncGenerator[str, None]:
"""Retrieve all keys in the store.

Returns
-------
list[str]
AsyncGenerator[str, None]
"""
...

@abstractmethod
async def list_prefix(self, prefix: str) -> List[str]:
def list_prefix(self, prefix: str) -> AsyncGenerator[str, None]:
"""Retrieve all keys in the store with a given prefix.

Parameters
@@ -125,12 +126,12 @@ async def list_prefix(self, prefix: str) -> List[str]:

Returns
-------
list[str]
AsyncGenerator[str, None]
"""
...

@abstractmethod
async def list_dir(self, prefix: str) -> List[str]:
def list_dir(self, prefix: str) -> AsyncGenerator[str, None]:
"""
Retrieve all keys and prefixes with a given prefix and which do not contain the character
“/” after the given prefix.
@@ -141,6 +142,6 @@ async def list_dir(self, prefix: str) -> List[str]:

Returns
-------
list[str]
AsyncGenerator[str, None]
"""
...
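
A minimal sketch of how the new generator-returning listing methods are consumed: keys stream lazily, and an async comprehension collects them when a full list is still wanted. The helper function names below are illustrative, not part of this diff.

    # Hypothetical consumers of the AsyncGenerator-based listing API.
    from zarr.abc.store import Store


    async def all_keys(store: Store, prefix: str) -> list[str]:
        # Collect every key under `prefix` without the store building the list up front.
        return [key async for key in store.list_prefix(prefix)]


    async def print_keys(store: Store) -> None:
        # Or simply iterate as keys are produced.
        async for key in store.list():
            print(key)
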
16 changes: 8 additions & 8 deletions src/zarr/group.py
@@ -306,20 +306,20 @@ async def members(self) -> AsyncGenerator[tuple[str, AsyncArray | AsyncGroup], None]:
)

raise ValueError(msg)
subkeys = await self.store_path.store.list_dir(self.store_path.path)
# would be nice to make these special keys accessible programmatically,
# and scoped to specific zarr versions
subkeys_filtered = filter(lambda v: v not in ("zarr.json", ".zgroup", ".zattrs"), subkeys)
# is there a better way to schedule this?
for subkey in subkeys_filtered:
_skip_keys = ("zarr.json", ".zgroup", ".zattrs")
async for key in self.store_path.store.list_dir(self.store_path.path):
if key in _skip_keys:
continue
try:
yield (subkey, await self.getitem(subkey))
yield (key, await self.getitem(key))
except KeyError:
# keyerror is raised when `subkey` names an object (in the object storage sense),
# keyerror is raised when `key` names an object (in the object storage sense),
# as opposed to a prefix, in the store under the prefix associated with this group
# in which case `subkey` cannot be the name of a sub-array or sub-group.
# in which case `key` cannot be the name of a sub-array or sub-group.
logger.warning(
"Object at %s is not recognized as a component of a Zarr hierarchy.", subkey
"Object at %s is not recognized as a component of a Zarr hierarchy.", key
)
pass
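
A minimal sketch of consuming the now-lazy members() generator, assuming AsyncGroup lives in zarr.group as the hunk header above suggests; the function name is illustrative.

    # Hypothetical usage: members() now yields (name, node) pairs as keys are
    # discovered via list_dir, instead of returning a fully built list.
    from zarr.group import AsyncGroup


    async def print_children(group: AsyncGroup) -> None:
        async for name, node in group.members():
            print(f"{name}: {type(node).__name__}")
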

86 changes: 53 additions & 33 deletions src/zarr/store/local.py
@@ -2,6 +2,7 @@

import io
import shutil
from collections.abc import AsyncGenerator
from pathlib import Path
from typing import Union, Optional, List, Tuple

@@ -10,8 +11,24 @@


def _get(path: Path, byte_range: Optional[Tuple[int, Optional[int]]] = None) -> bytes:
"""
Fetch a contiguous region of bytes from a file.
Parameters
----------
path: Path
The file to read bytes from.
byte_range: Optional[Tuple[int, Optional[int]]] = None
The range of bytes to read. If `byte_range` is `None`, then the entire file will be read.
If `byte_range` is a tuple, the first value specifies the index of the first byte to read,
and the second value specifies the total number of bytes to read. If the total value is
`None`, then the entire file after the first byte will be read.
"""
if byte_range is not None:
start = byte_range[0]
if byte_range[0] is None:
start = 0
else:
start = byte_range[0]

end = (start + byte_range[1]) if byte_range[1] is not None else None
else:
return path.read_bytes()
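
Restating the byte_range semantics documented above as a standalone sketch, outside the diff; the temporary file is illustrative and _get is a private helper of this module.

    # Hypothetical illustration of the (start, length) interpretation of byte_range.
    import tempfile
    from pathlib import Path

    from zarr.store.local import _get

    path = Path(tempfile.mkdtemp()) / "example-bytes"  # illustrative location
    path.write_bytes(b"0123456789")

    assert _get(path) == b"0123456789"       # byte_range=None reads the whole file
    assert _get(path, (2, 4)) == b"2345"     # 4 bytes starting at offset 2
    assert _get(path, (6, None)) == b"6789"  # from offset 6 to the end of the file
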
@@ -84,21 +101,28 @@ async def get(

async def get_partial_values(
self, key_ranges: List[Tuple[str, Tuple[int, int]]]
) -> List[bytes]:
) -> List[Optional[bytes]]:
"""
Read byte ranges from multiple keys.
Parameters
----------
key_ranges: List[Tuple[str, Tuple[int, int]]]
A list of (key, (start, length)) tuples. The first element of the tuple is the name of
the key in storage to fetch bytes from. The second element the tuple defines the byte
range to retrieve. These values are arguments to `get`, as this method wraps
concurrent invocation of `get`.
"""
args = []
for key, byte_range in key_ranges:
assert isinstance(key, str)
path = self.root / key
if byte_range is not None:
args.append((_get, path, byte_range[0], byte_range[1]))
else:
args.append((_get, path))
args.append((_get, path, byte_range))
return await concurrent_map(args, to_thread, limit=None) # TODO: fix limit

async def set(self, key: str, value: BytesLike) -> None:
assert isinstance(key, str)
path = self.root / key
await to_thread(_put, path, value)
await to_thread(_put, path, value, auto_mkdir=self.auto_mkdir)

async def set_partial_values(self, key_start_values: List[Tuple[str, int, bytes]]) -> None:
args = []
@@ -122,22 +146,19 @@ async def exists(self, key: str) -> bool:
path = self.root / key
return await to_thread(path.is_file)

async def list(self) -> List[str]:
async def list(self) -> AsyncGenerator[str, None]:
"""Retrieve all keys in the store.

Returns
-------
list[str]
AsyncGenerator[str, None]
"""
to_strip = str(self.root) + "/"
for p in list(self.root.rglob("*")):
if p.is_file():
yield str(p).replace(to_strip, "")

# Q: do we want to return strings or Paths?
def _list(root: Path) -> List[str]:
files = [str(p) for p in root.rglob("") if p.is_file()]
return files

return await to_thread(_list, self.root)

async def list_prefix(self, prefix: str) -> List[str]:
async def list_prefix(self, prefix: str) -> AsyncGenerator[str, None]:
"""Retrieve all keys in the store with a given prefix.

Parameters
@@ -146,16 +167,15 @@ async def list_prefix(self, prefix: str) -> List[str]:

Returns
-------
list[str]
AsyncGenerator[str, None]
"""

def _list_prefix(root: Path, prefix: str) -> List[str]:
files = [str(p) for p in (root / prefix).rglob("*") if p.is_file()]
return files
to_strip = str(self.root) + "/"
for p in (self.root / prefix).rglob("*"):
if p.is_file():
yield str(p).replace(to_strip, "")

return await to_thread(_list_prefix, self.root, prefix)

async def list_dir(self, prefix: str) -> List[str]:
async def list_dir(self, prefix: str) -> AsyncGenerator[str, None]:
"""
Retrieve all keys and prefixes with a given prefix and which do not contain the character
“/” after the given prefix.
@@ -166,15 +186,15 @@ async def list_dir(self, prefix: str) -> List[str]:

Returns
-------
list[str]
AsyncGenerator[str, None]
"""

def _list_dir(root: Path, prefix: str) -> List[str]:
base = root / prefix
to_strip = str(base) + "/"
try:
return [str(key).replace(to_strip, "") for key in base.iterdir()]
except (FileNotFoundError, NotADirectoryError):
return []
base = self.root / prefix
to_strip = str(base) + "/"

return await to_thread(_list_dir, self.root, prefix)
try:
key_iter = base.iterdir()
for key in key_iter:
yield str(key).replace(to_strip, "")
except (FileNotFoundError, NotADirectoryError):
pass
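
A minimal usage sketch of the concurrent get_partial_values described above. The LocalStore name, its root-path constructor, and automatic creation of parent directories on set are assumptions here, not guaranteed by this diff.

    # Hypothetical round trip: write one key, then read two byte ranges of it
    # back in a single concurrent call.
    import asyncio
    import tempfile

    from zarr.store.local import LocalStore


    async def main() -> None:
        # LocalStore(root) and auto-created parent directories are assumed.
        store = LocalStore(tempfile.mkdtemp())
        await store.set("foo/c/0", b"\x01\x02\x03\x04")
        ranges = [("foo/c/0", (0, 2)), ("foo/c/0", (2, 2))]  # (key, (start, length))
        first, second = await store.get_partial_values(ranges)
        assert first == b"\x01\x02" and second == b"\x03\x04"


    asyncio.run(main())
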
37 changes: 21 additions & 16 deletions src/zarr/store/memory.py
@@ -1,8 +1,9 @@
from __future__ import annotations

from collections.abc import AsyncGenerator
from typing import Optional, MutableMapping, List, Tuple

from zarr.common import BytesLike
from zarr.common import BytesLike, concurrent_map
from zarr.abc.store import Store


@@ -38,8 +39,9 @@ async def get(

async def get_partial_values(
self, key_ranges: List[Tuple[str, Tuple[int, int]]]
) -> List[bytes]:
raise NotImplementedError
) -> List[Optional[BytesLike]]:
vals = await concurrent_map(key_ranges, self.get, limit=None)
return vals

async def exists(self, key: str) -> bool:
return key in self._store_dict
@@ -67,20 +69,23 @@ async def delete(self, key: str) -> None:
async def set_partial_values(self, key_start_values: List[Tuple[str, int, bytes]]) -> None:
raise NotImplementedError

async def list(self) -> List[str]:
return list(self._store_dict.keys())
async def list(self) -> AsyncGenerator[str, None]:
for key in self._store_dict:
yield key

async def list_prefix(self, prefix: str) -> AsyncGenerator[str, None]:
for key in self._store_dict:
if key.startswith(prefix):
yield key

async def list_prefix(self, prefix: str) -> List[str]:
return [key for key in self._store_dict if key.startswith(prefix)]
async def list_dir(self, prefix: str) -> AsyncGenerator[str, None]:
if prefix.endswith("/"):
prefix = prefix[:-1]

async def list_dir(self, prefix: str) -> List[str]:
if prefix == "":
return list({key.split("/", maxsplit=1)[0] for key in self._store_dict})
for key in self._store_dict:
yield key.split("/", maxsplit=1)[0]
else:
return list(
{
key.strip(prefix + "/").split("/")[0]
for key in self._store_dict
if (key.startswith(prefix + "/") and key != prefix)
}
)
for key in self._store_dict:
if key.startswith(prefix + "/") and key != prefix:
yield key.strip(prefix + "/").split("/")[0]
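
A small sketch of the list_dir behaviour implemented above, mirroring the expectations in the test harness below; the MemoryStore name and its no-argument constructor are assumptions.

    # Hypothetical illustration: list_dir yields only the immediate children
    # under the prefix, with the prefix and anything deeper stripped off.
    from zarr.store.memory import MemoryStore


    async def immediate_children() -> list[str]:
        store = MemoryStore()  # no-argument construction assumed
        await store.set("foo/zarr.json", b"bar")
        await store.set("foo/c/1", b"\x01")
        return [child async for child in store.list_dir("foo")]  # ["zarr.json", "c"]
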
9 changes: 9 additions & 0 deletions src/zarr/testing/__init__.py
@@ -0,0 +1,9 @@
import importlib.util
import warnings

if importlib.util.find_spec("pytest") is not None:
from zarr.testing.store import StoreTests
else:
warnings.warn("pytest not installed, skipping test suite")

__all__ = ["StoreTests"]
81 changes: 81 additions & 0 deletions src/zarr/testing/store.py
@@ -0,0 +1,81 @@
import pytest
@jhamman (Member, Author) commented on May 7, 2024:
The idea here is that we will provide a test harness for the generic store interface. This doesn't mean we can't have store-specific tests but those can go in our tests directory.

Importing zarr.testing.store will raise an ImportError if pytest is not installed. I think that could be fine but calling it out so others have a chance to comment. I think there are some things we can do to further hide this import if folks object to the current configuration.


from zarr.abc.store import Store


class StoreTests:
store_cls: type[Store]

@pytest.fixture(scope="function")
def store(self) -> Store:
return self.store_cls()

def test_store_type(self, store: Store) -> None:
assert isinstance(store, Store)
assert isinstance(store, self.store_cls)

def test_store_repr(self, store: Store) -> None:
assert repr(store)

def test_store_capabilities(self, store: Store) -> None:
assert store.supports_writes
assert store.supports_partial_writes
assert store.supports_listing

@pytest.mark.parametrize("key", ["c/0", "foo/c/0.0", "foo/0/0"])
@pytest.mark.parametrize("data", [b"\x01\x02\x03\x04", b""])
async def test_set_get_bytes_roundtrip(self, store: Store, key: str, data: bytes) -> None:
await store.set(key, data)
assert await store.get(key) == data

@pytest.mark.parametrize("key", ["foo/c/0"])
@pytest.mark.parametrize("data", [b"\x01\x02\x03\x04", b""])
async def test_get_partial_values(self, store: Store, key: str, data: bytes) -> None:
# put all of the data
await store.set(key, data)
# read back just part of it
vals = await store.get_partial_values([(key, (0, 2))])
assert vals == [data[0:2]]

# read back multiple parts of it at once
vals = await store.get_partial_values([(key, (0, 2)), (key, (2, 4))])
assert vals == [data[0:2], data[2:4]]

async def test_exists(self, store: Store) -> None:
assert not await store.exists("foo")
await store.set("foo/zarr.json", b"bar")
assert await store.exists("foo/zarr.json")

async def test_delete(self, store: Store) -> None:
await store.set("foo/zarr.json", b"bar")
assert await store.exists("foo/zarr.json")
await store.delete("foo/zarr.json")
assert not await store.exists("foo/zarr.json")

async def test_list(self, store: Store) -> None:
assert [k async for k in store.list()] == []
await store.set("foo/zarr.json", b"bar")
keys = [k async for k in store.list()]
assert keys == ["foo/zarr.json"], keys

expected = ["foo/zarr.json"]
for i in range(10):
key = f"foo/c/{i}"
expected.append(key)
await store.set(f"foo/c/{i}", i.to_bytes(length=3, byteorder="little"))

async def test_list_prefix(self, store: Store) -> None:
# TODO: we currently don't use list_prefix anywhere
pass

async def test_list_dir(self, store: Store) -> None:
assert [k async for k in store.list_dir("")] == []
assert [k async for k in store.list_dir("foo")] == []
await store.set("foo/zarr.json", b"bar")
await store.set("foo/c/1", b"\x01")

keys = [k async for k in store.list_dir("foo")]
assert keys == ["zarr.json", "c"], keys

keys = [k async for k in store.list_dir("foo/")]
assert keys == ["zarr.json", "c"], keys
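
Given the comment above about providing a generic test harness, a minimal sketch of how a concrete store might opt in; the test class name is illustrative and the MemoryStore no-argument constructor is an assumption.

    # Hypothetical test module reusing the generic harness; store-specific
    # tests can still live alongside it in the tests directory.
    from zarr.store.memory import MemoryStore
    from zarr.testing.store import StoreTests


    class TestMemoryStore(StoreTests):
        store_cls = MemoryStore
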