Initial Sharding Prototype #1

Closed
wants to merge 5 commits into from
Changes from 3 commits

24 changes: 24 additions & 0 deletions chunking_test.py
@@ -0,0 +1,24 @@
import json
import os

import zarr

store = zarr.DirectoryStore("data/chunking_test.zarr")
z = zarr.zeros((20, 3), chunks=(3, 3), shards=(2, 2), store=store, overwrite=True, compressor=None)
Member: shards is specified in units of chunks?

Author: Yes 👍
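(Concretely: with chunks=(3, 3) and shards=(2, 2) as above, each shard object on disk packs up to 2 × 2 = 4 chunks, i.e. a 6 × 6 block of array elements, clipped at the array boundary.)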

z[...] = 42
z[15, 1] = 389
z[19, 2] = 1
z[0, 1] = -4.2

print("ONDISK", sorted(os.listdir("data/chunking_test.zarr")))
assert json.loads(store[".zarray"].decode()) ["shards"] == [2, 2]

print("STORE", list(store))
print("CHUNKSTORE (SHARDED)", list(z.chunk_store))

z_reopened = zarr.open("data/chunking_test.zarr")
assert z_reopened.shards == (2, 2)
assert z_reopened[15, 1] == 389
assert z_reopened[19, 2] == 1
assert z_reopened[0, 1] == -4.2
assert z_reopened[0, 0] == 42
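As a back-of-the-envelope check (not part of the PR, plain Python only), the shard layout this test should produce can be worked out by hand; the sketch below mirrors the arithmetic in ShardedStore.__key_to_sharded__ further down:

import math

shape, chunks, shards = (20, 3), (3, 3), (2, 2)
chunk_grid = tuple(math.ceil(s / c) for s, c in zip(shape, chunks))         # (7, 1) chunks
shard_grid = tuple(math.ceil(g / sh) for g, sh in zip(chunk_grid, shards))  # (4, 1) shard objects on disk

# Element (15, 1) falls into chunk (5, 0), which is routed to shard "2.0"
# (assuming DirectoryStore's default "." dimension separator).
chunk_idx = tuple(e // c for e, c in zip((15, 1), chunks))        # (5, 0)
shard_idx = tuple(ci // sh for ci, sh in zip(chunk_idx, shards))  # (2, 0)
print(chunk_grid, shard_grid, chunk_idx, shard_idx)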
102 changes: 102 additions & 0 deletions zarr/_storage/sharded_store.py
@@ -0,0 +1,102 @@
from functools import reduce
from itertools import product
from typing import Any, Iterator, Optional, Sequence, Tuple

from zarr._storage.store import BaseStore, Store
from zarr.storage import StoreLike, array_meta_key, attrs_key, group_meta_key


def _cum_prod(x: Sequence[int]) -> Iterator[int]:
prod = 1
yield prod
for i in x[:-1]:
prod *= i
yield prod


class ShardedStore(Store):
"""This class should not be used directly,
but is added to an Array as a wrapper when needed automatically."""

    def __init__(
        self,
        store: StoreLike,
        shards: Tuple[int, ...],
        dimension_separator: str,
        chunk_has_constant_size: bool,
        fill_value: bytes,
        value_len: Optional[int],
    ) -> None:
self._store: BaseStore = BaseStore._ensure_store(store)
self._shards = shards
# This defines C/F-order
self._shards_cumprod = tuple(_cum_prod(shards))
self._num_chunks_per_shard = reduce(lambda x, y: x*y, shards, 1)
self._dimension_separator = dimension_separator
# TODO: add jumptable for compressed data
assert not chunk_has_constant_size, "Currently only uncompressed data can be used."
self._chunk_has_constant_size = chunk_has_constant_size
if not chunk_has_constant_size:
assert value_len is not None
self._fill_chunk = fill_value * value_len
else:
self._fill_chunk = None

# TODO: add warnings for ineffective reads/writes:
# * warn if partial reads are not available
# * optionally warn on unaligned writes if no partial writes are available

def __key_to_sharded__(self, key: str) -> Tuple[str, int]:
# TODO: allow to be in a group (aka only use last parts for dimensions)
subkeys = map(int, key.split(self._dimension_separator))

shard_tuple, index_tuple = zip(*((subkey // shard_i, subkey % shard_i) for subkey, shard_i in zip(subkeys, self._shards)))
shard_key = self._dimension_separator.join(map(str, shard_tuple))
index = sum(i * j for i, j in zip(index_tuple, self._shards_cumprod))
return shard_key, index

    def __get_chunk_slice__(self, shard_key: str, shard_index: int) -> slice:
# TODO: here we would use the jumptable for compression
start = shard_index * len(self._fill_chunk)
return slice(start, start + len(self._fill_chunk))

def __getitem__(self, key: str) -> bytes:
shard_key, shard_index = self.__key_to_sharded__(key)
chunk_slice = self.__get_chunk_slice__(shard_key, shard_index)
# TODO use partial reads if available
full_shard_value = self._store[shard_key]
return full_shard_value[chunk_slice]

def __setitem__(self, key: str, value: bytes) -> None:
shard_key, shard_index = self.__key_to_sharded__(key)
if shard_key in self._store:
full_shard_value = bytearray(self._store[shard_key])
else:
full_shard_value = bytearray(self._fill_chunk * self._num_chunks_per_shard)
chunk_slice = self.__get_chunk_slice__(shard_key, shard_index)
# TODO use partial writes if available
full_shard_value[chunk_slice] = value
self._store[shard_key] = full_shard_value

def __delitem__(self, key) -> None:
# TODO not implemented yet
# For uncompressed chunks, deleting the "last" chunk might need to be detected.
raise NotImplementedError("Deletion is not yet implemented")

def __iter__(self) -> Iterator[str]:
for shard_key in self._store.__iter__():
if any(shard_key.endswith(i) for i in (array_meta_key, group_meta_key, attrs_key)):
yield shard_key
else:
# TODO: allow to be in a group (aka only use last parts for dimensions)
subkeys = tuple(map(int, shard_key.split(self._dimension_separator)))
for offset in product(*(range(i) for i in self._shards)):
original_key = (subkeys_i * shards_i + offset_i for subkeys_i, offset_i, shards_i in zip(subkeys, offset, self._shards))
yield self._dimension_separator.join(map(str, original_key))

def __len__(self) -> int:
return sum(1 for _ in self.keys())

# TODO: For efficient reads and writes, we need to implement
# getitems, setitems & delitems
# and combine writes/reads/deletions to the same shard.
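For orientation, the key routing above can be exercised directly with a plain dict as the inner store; this is a hypothetical sketch (the class is not meant to be used this way, per its docstring) and assumes this branch is installed:

import numpy as np
from zarr._storage.sharded_store import ShardedStore

itemsize = np.dtype("f8").itemsize
sharded = ShardedStore(
    {},                              # inner store: one object per shard
    shards=(2, 2),                   # 2 x 2 chunks per shard
    dimension_separator=".",
    chunk_has_constant_size=False,   # uncompressed chunks, as the prototype requires
    fill_value=b"\x00" * itemsize,   # bytes of a single fill element
    value_len=3 * 3,                 # elements per chunk, e.g. for chunks=(3, 3)
)
chunk_bytes = b"\x01" * (3 * 3 * itemsize)   # one uncompressed chunk worth of bytes
sharded["6.0"] = chunk_bytes                 # chunk (6, 0) -> shard "3.0", slot index 0
assert bytes(sharded["6.0"]) == chunk_bytes
print(list(sharded))                         # all chunk keys covered by the shard that now exists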
1 change: 1 addition & 0 deletions zarr/_storage/store.py
@@ -110,6 +110,7 @@ def _ensure_store(store: Any):


class Store(BaseStore):
# TODO: document methods which allow optimizations, e.g. delitems, setitems, getitems, listdir, …
"""Abstract store class used by implementations following the Zarr v2 spec.

Adds public `listdir`, `rename`, and `rmdir` methods on top of BaseStore.
37 changes: 32 additions & 5 deletions zarr/core.py
@@ -10,6 +10,7 @@
from numcodecs.compat import ensure_bytes, ensure_ndarray

from collections.abc import MutableMapping
from zarr._storage.sharded_store import ShardedStore

from zarr.attrs import Attributes
from zarr.codecs import AsType, get_codec
@@ -213,6 +214,7 @@ def _load_metadata_nosync(self):
self._meta = meta
self._shape = meta['shape']
self._chunks = meta['chunks']
self._shards = meta.get('shards')
self._dtype = meta['dtype']
self._fill_value = meta['fill_value']
self._order = meta['order']
@@ -264,7 +266,9 @@ def _flush_metadata_nosync(self):
filters_config = None
meta = dict(shape=self._shape, chunks=self._chunks, dtype=self._dtype,
compressor=compressor_config, fill_value=self._fill_value,
order=self._order, filters=filters_config)
order=self._order, filters=filters_config, shards=self._shards)
if self._shards is not None:
meta['shards'] = self._shards
mkey = self._key_prefix + array_meta_key
self._store[mkey] = self._store._metadata_class.encode_array_metadata(meta)

@@ -307,11 +311,26 @@ def read_only(self, value):

@property
def chunk_store(self):
    """A MutableMapping providing the underlying storage for array chunks."""
    if self._chunk_store is None:
        chunk_store = self._store
    else:
        chunk_store = self._chunk_store
    if self._shards is None:
        return chunk_store
    else:
        try:
            return self._cached_sharded_store
        except AttributeError:
            self._cached_sharded_store = BaseStore._ensure_store(ShardedStore(
                chunk_store,
                shards=self._shards,
                dimension_separator=self._dimension_separator,
                chunk_has_constant_size=self._compressor is not None,  # TODO add exceptions, e.g. dtype==object
                fill_value=np.full(1, fill_value=self._fill_value or 0, dtype=self._dtype).tobytes(),
                value_len=reduce(operator.mul, self._chunks, 1),
            ))
            return self._cached_sharded_store
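In other words, the wrapping is lazy and cached: the first access to chunk_store on a sharded array builds the ShardedStore around the regular chunk store, and later accesses reuse it. A minimal check (an illustrative sketch, assuming this branch is installed):

import zarr

z = zarr.zeros((20, 3), chunks=(3, 3), shards=(2, 2), compressor=None)
assert type(z.chunk_store).__name__ == "ShardedStore"
assert z.chunk_store is z.chunk_store   # built once, then cached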

@property
def shape(self):
@@ -332,6 +351,12 @@ def chunks(self):
chunk of the array."""
return self._chunks

@property
def shards(self):
"""A tuple of integers describing the number of chunks in each shard
of the array."""
return self._shards

@property
def dtype(self):
"""The NumPy data type."""
@@ -1899,7 +1924,7 @@ def _chunk_getitems(self, lchunk_coords, lchunk_selection, out, lout_selection,
and hasattr(self._compressor, "decode_partial")
and not fields
and self.dtype != object
and hasattr(self.chunk_store, "getitems")
and hasattr(self.chunk_store, "getitems") # TODO: this should rather check for read_block or similar
):
partial_read_decode = True
cdatas = {
@@ -2236,6 +2261,7 @@ def digest(self, hashname="sha1"):

h = hashlib.new(hashname)

# TODO: operate on shards here if available:
for i in itertools.product(*[range(s) for s in self.cdata_shape]):
h.update(self.chunk_store.get(self._chunk_key(i), b""))

@@ -2362,6 +2388,7 @@ def _resize_nosync(self, *args):
except KeyError:
# chunk not initialized
pass
# TODO: collect all chunks to delete and use _chunk_delitems

def append(self, data, axis=0):
"""Append `data` to `axis`.
6 changes: 4 additions & 2 deletions zarr/creation.py
@@ -1,3 +1,4 @@
from typing import Optional, Tuple, Union
from warnings import warn

import numpy as np
@@ -19,7 +20,8 @@ def create(shape, chunks=True, dtype=None, compressor='default',
fill_value=0, order='C', store=None, synchronizer=None,
overwrite=False, path=None, chunk_store=None, filters=None,
cache_metadata=True, cache_attrs=True, read_only=False,
object_codec=None, dimension_separator=None, write_empty_chunks=True, **kwargs):
object_codec=None, dimension_separator=None, write_empty_chunks=True,
object_codec=None, dimension_separator=None, write_empty_chunks=True,
shards: Union[int, Tuple[int, ...], None] = None, **kwargs):
"""Create an array.

Parameters
@@ -145,7 +147,7 @@ def create(shape, chunks=True, dtype=None, compressor='default',
init_array(store, shape=shape, chunks=chunks, dtype=dtype, compressor=compressor,
fill_value=fill_value, order=order, overwrite=overwrite, path=path,
chunk_store=chunk_store, filters=filters, object_codec=object_codec,
dimension_separator=dimension_separator)
dimension_separator=dimension_separator, shards=shards)

# instantiate array
z = Array(store, path=path, chunk_store=chunk_store, synchronizer=synchronizer,
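The new shards argument is simply forwarded from create to init_array; a sketch of the resulting API (assuming this branch is installed), including the 1-D convenience form that normalize_shards in zarr/util.py expands:

import zarr

z1 = zarr.create(shape=(20, 3), chunks=(3, 3), shards=(2, 2), compressor=None)
z2 = zarr.create(shape=(20, 3), chunks=(3, 3), shards=2, compressor=None)   # expands to (2, 2)
assert z1.shards == z2.shards == (2, 2)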
8 changes: 6 additions & 2 deletions zarr/meta.py
@@ -51,6 +51,7 @@ def decode_array_metadata(cls, s: Union[MappingType, str]) -> MappingType[str, A
object_codec = None

dimension_separator = meta.get("dimension_separator", None)
shards = meta.get("shards", None)
fill_value = cls.decode_fill_value(meta['fill_value'], dtype, object_codec)
meta = dict(
zarr_format=meta["zarr_format"],
@@ -64,6 +65,8 @@ def decode_array_metadata(cls, s: Union[MappingType, str]) -> MappingType[str, A
)
if dimension_separator:
meta['dimension_separator'] = dimension_separator
if shards:
meta['shards'] = tuple(shards)
except Exception as e:
raise MetadataError("error decoding metadata") from e
else:
Expand All @@ -77,6 +80,7 @@ def encode_array_metadata(cls, meta: MappingType[str, Any]) -> bytes:
dtype, sdshape = dtype.subdtype

dimension_separator = meta.get("dimension_separator")
shards = meta.get("shards")
if dtype.hasobject:
import numcodecs
object_codec = numcodecs.get_codec(meta['filters'][0])
@@ -96,8 +100,8 @@ def encode_array_metadata(cls, meta: MappingType[str, Any]) -> bytes:
if dimension_separator:
meta['dimension_separator'] = dimension_separator

if dimension_separator:
meta["dimension_separator"] = dimension_separator
if shards:
meta['shards'] = shards

return json_dumps(meta)
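For the chunking_test.py array, json.loads(store[".zarray"]) would then contain roughly the following (an illustrative sketch; exact defaults and key order may differ):

{
    "zarr_format": 2,
    "shape": [20, 3],
    "chunks": [3, 3],
    "shards": [2, 2],       # the new key; omitted entirely when sharding is off
    "dtype": "<f8",
    "compressor": None,     # JSON null
    "fill_value": 0.0,
    "order": "C",
    "filters": None,
}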

10 changes: 8 additions & 2 deletions zarr/storage.py
@@ -54,7 +54,7 @@
from zarr.util import (buffer_size, json_loads, nolock, normalize_chunks,
normalize_dimension_separator,
normalize_dtype, normalize_fill_value, normalize_order,
normalize_shape, normalize_storage_path, retry_call)
normalize_shape, normalize_shards, normalize_storage_path, retry_call)

from zarr._storage.absstore import ABSStore # noqa: F401
from zarr._storage.store import (_listdir_from_keys,
@@ -236,6 +237,7 @@ def init_array(
filters=None,
object_codec=None,
dimension_separator=None,
shards: Union[int, Tuple[int, ...], None] = None,
):
"""Initialize an array store with the given configuration. Note that this is a low-level
function and there should be no need to call this directly from user code.
@@ -353,7 +354,8 @@
order=order, overwrite=overwrite, path=path,
chunk_store=chunk_store, filters=filters,
object_codec=object_codec,
dimension_separator=dimension_separator)
dimension_separator=dimension_separator,
shards=shards)


def _init_array_metadata(
@@ -370,6 +372,7 @@ def _init_array_metadata(
filters=None,
object_codec=None,
dimension_separator=None,
shards: Union[int, Tuple[int, ...], None] = None,
):

# guard conditions
@@ -388,6 +391,7 @@ def _init_array_metadata(
shape = normalize_shape(shape) + dtype.shape
dtype = dtype.base
chunks = normalize_chunks(chunks, shape, dtype.itemsize)
shards = normalize_shards(shards, shape)
order = normalize_order(order)
fill_value = normalize_fill_value(fill_value, dtype)

@@ -445,6 +449,8 @@ def _init_array_metadata(
compressor=compressor_config, fill_value=fill_value,
order=order, filters=filters_config,
dimension_separator=dimension_separator)
if shards is not None:
meta["shards"] = shards
key = _path_to_prefix(path) + array_meta_key
if hasattr(store, '_metadata_class'):
store[key] = store._metadata_class.encode_array_metadata(meta) # type: ignore
33 changes: 33 additions & 0 deletions zarr/util.py
@@ -149,6 +149,38 @@ def normalize_chunks(
return tuple(chunks)


def normalize_shards(
    shards: Union[int, Tuple[int, ...], None], shape: Tuple[int, ...],
) -> Optional[Tuple[int, ...]]:
"""Convenience function to normalize the `shards` argument for an array
with the given `shape`."""

# N.B., expect shape already normalized

if shards is None:
return None

# handle 1D convenience form
if isinstance(shards, numbers.Integral):
shards = tuple(int(shards) for _ in shape)

# handle bad dimensionality
if len(shards) > len(shape):
raise ValueError('too many dimensions in shards')

# handle underspecified shards
if len(shards) < len(shape):
# assume single shards across remaining dimensions
shards += (1, ) * (len(shape) - len(shards))

# handle None or -1 in shards
if -1 in shards or None in shards:
shards = tuple(s if c == -1 or c is None else int(c)
for s, c in zip(shape, shards))

return tuple(shards)
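A few expected normalizations as a quick sketch (runnable only with this branch installed):

from zarr.util import normalize_shards

assert normalize_shards(None, (20, 3)) is None           # sharding disabled
assert normalize_shards(2, (20, 3)) == (2, 2)            # 1-D convenience form
assert normalize_shards((2,), (20, 3)) == (2, 1)         # padded with single-chunk shards
assert normalize_shards((-1, 2), (20, 3)) == (20, 2)     # -1/None falls back to the shape entry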


def normalize_dtype(dtype: Union[str, np.dtype], object_codec) -> Tuple[np.dtype, Any]:

# convenience API for object arrays
@@ -560,6 +592,7 @@ def __init__(self, store_key, chunk_store):
# is it fsstore or an actual fsspec map object
assert hasattr(self.chunk_store, "map")
self.map = self.chunk_store.map
# TODO maybe use partial_read here also
self.fs = self.chunk_store.fs
self.store_key = store_key
self.buff = None