Skip to content

Commit e6ef2b1

Browse files
authored
optimize shard writing (#3561)
* optimize shard writing Writing to sharded arrays was up to 10x slower for largish chunk sizes because the _ShardBuilder object has many calls to np.concatenate. This commit coalesces these into a single concatenate call, and improves write performance by a factor of 10 on the benchmarking script in #3560. Added a new core.Buffer.combine API Resolves #3560 Signed-off-by: Noah D. Brenowitz <nbren12@gmail.com> * remove redundant method Signed-off-by: Noah D. Brenowitz <nbren12@gmail.com> * remove redundant np.asanyarray Signed-off-by: Noah D. Brenowitz <nbren12@gmail.com> * clarify ShardBuilder API remove inheritance, hide the index attribute and remove some indirection Signed-off-by: Noah D. Brenowitz <nbren12@gmail.com> * Remove shard builder objects just use dicts Signed-off-by: Noah D. Brenowitz <nbren12@gmail.com> * fix missing chunk case Signed-off-by: Noah D. Brenowitz <nbren12@gmail.com> * add release note --------- Signed-off-by: Noah D. Brenowitz <nbren12@gmail.com>
1 parent b3e9aed commit e6ef2b1

File tree

5 files changed

+87
-146
lines changed

5 files changed

+87
-146
lines changed

changes/3560.bugfix.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Improve write performance for large shards by up to 10x.

src/zarr/codecs/sharding.py

Lines changed: 64 additions & 129 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
from __future__ import annotations
22

33
from collections.abc import Iterable, Mapping, MutableMapping
4-
from dataclasses import dataclass, field, replace
4+
from dataclasses import dataclass, replace
55
from enum import Enum
66
from functools import lru_cache
77
from operator import itemgetter
@@ -54,15 +54,15 @@
5454
from zarr.registry import get_ndbuffer_class, get_pipeline_class
5555

5656
if TYPE_CHECKING:
57-
from collections.abc import Awaitable, Callable, Iterator
57+
from collections.abc import Iterator
5858
from typing import Self
5959

6060
from zarr.core.common import JSON
6161
from zarr.core.dtype.wrapper import TBaseDType, TBaseScalar, ZDType
6262

6363
MAX_UINT_64 = 2**64 - 1
64-
ShardMapping = Mapping[tuple[int, ...], Buffer]
65-
ShardMutableMapping = MutableMapping[tuple[int, ...], Buffer]
64+
ShardMapping = Mapping[tuple[int, ...], Buffer | None]
65+
ShardMutableMapping = MutableMapping[tuple[int, ...], Buffer | None]
6666

6767

6868
class ShardingCodecIndexLocation(Enum):
@@ -219,114 +219,6 @@ def __len__(self) -> int:
219219
def __iter__(self) -> Iterator[tuple[int, ...]]:
220220
return c_order_iter(self.index.offsets_and_lengths.shape[:-1])
221221

222-
def is_empty(self) -> bool:
223-
return self.index.is_all_empty()
224-
225-
226-
class _ShardBuilder(_ShardReader, ShardMutableMapping):
227-
buf: Buffer
228-
index: _ShardIndex
229-
230-
@classmethod
231-
def merge_with_morton_order(
232-
cls,
233-
chunks_per_shard: tuple[int, ...],
234-
tombstones: set[tuple[int, ...]],
235-
*shard_dicts: ShardMapping,
236-
) -> _ShardBuilder:
237-
obj = cls.create_empty(chunks_per_shard)
238-
for chunk_coords in morton_order_iter(chunks_per_shard):
239-
if chunk_coords in tombstones:
240-
continue
241-
for shard_dict in shard_dicts:
242-
maybe_value = shard_dict.get(chunk_coords, None)
243-
if maybe_value is not None:
244-
obj[chunk_coords] = maybe_value
245-
break
246-
return obj
247-
248-
@classmethod
249-
def create_empty(
250-
cls, chunks_per_shard: tuple[int, ...], buffer_prototype: BufferPrototype | None = None
251-
) -> _ShardBuilder:
252-
if buffer_prototype is None:
253-
buffer_prototype = default_buffer_prototype()
254-
obj = cls()
255-
obj.buf = buffer_prototype.buffer.create_zero_length()
256-
obj.index = _ShardIndex.create_empty(chunks_per_shard)
257-
return obj
258-
259-
def __setitem__(self, chunk_coords: tuple[int, ...], value: Buffer) -> None:
260-
chunk_start = len(self.buf)
261-
chunk_length = len(value)
262-
self.buf += value
263-
self.index.set_chunk_slice(chunk_coords, slice(chunk_start, chunk_start + chunk_length))
264-
265-
def __delitem__(self, chunk_coords: tuple[int, ...]) -> None:
266-
raise NotImplementedError
267-
268-
async def finalize(
269-
self,
270-
index_location: ShardingCodecIndexLocation,
271-
index_encoder: Callable[[_ShardIndex], Awaitable[Buffer]],
272-
) -> Buffer:
273-
index_bytes = await index_encoder(self.index)
274-
if index_location == ShardingCodecIndexLocation.start:
275-
empty_chunks_mask = self.index.offsets_and_lengths[..., 0] == MAX_UINT_64
276-
self.index.offsets_and_lengths[~empty_chunks_mask, 0] += len(index_bytes)
277-
index_bytes = await index_encoder(self.index) # encode again with corrected offsets
278-
out_buf = index_bytes + self.buf
279-
else:
280-
out_buf = self.buf + index_bytes
281-
return out_buf
282-
283-
284-
@dataclass(frozen=True)
285-
class _MergingShardBuilder(ShardMutableMapping):
286-
old_dict: _ShardReader
287-
new_dict: _ShardBuilder
288-
tombstones: set[tuple[int, ...]] = field(default_factory=set)
289-
290-
def __getitem__(self, chunk_coords: tuple[int, ...]) -> Buffer:
291-
chunk_bytes_maybe = self.new_dict.get(chunk_coords)
292-
if chunk_bytes_maybe is not None:
293-
return chunk_bytes_maybe
294-
return self.old_dict[chunk_coords]
295-
296-
def __setitem__(self, chunk_coords: tuple[int, ...], value: Buffer) -> None:
297-
self.new_dict[chunk_coords] = value
298-
299-
def __delitem__(self, chunk_coords: tuple[int, ...]) -> None:
300-
self.tombstones.add(chunk_coords)
301-
302-
def __len__(self) -> int:
303-
return self.old_dict.__len__()
304-
305-
def __iter__(self) -> Iterator[tuple[int, ...]]:
306-
return self.old_dict.__iter__()
307-
308-
def is_empty(self) -> bool:
309-
full_chunk_coords_map = self.old_dict.index.get_full_chunk_map()
310-
full_chunk_coords_map = np.logical_or(
311-
full_chunk_coords_map, self.new_dict.index.get_full_chunk_map()
312-
)
313-
for tombstone in self.tombstones:
314-
full_chunk_coords_map[tombstone] = False
315-
return bool(np.array_equiv(full_chunk_coords_map, False))
316-
317-
async def finalize(
318-
self,
319-
index_location: ShardingCodecIndexLocation,
320-
index_encoder: Callable[[_ShardIndex], Awaitable[Buffer]],
321-
) -> Buffer:
322-
shard_builder = _ShardBuilder.merge_with_morton_order(
323-
self.new_dict.index.chunks_per_shard,
324-
self.tombstones,
325-
self.new_dict,
326-
self.old_dict,
327-
)
328-
return await shard_builder.finalize(index_location, index_encoder)
329-
330222

331223
@dataclass(frozen=True)
332224
class ShardingCodec(
@@ -573,7 +465,7 @@ async def _encode_single(
573465
)
574466
)
575467

576-
shard_builder = _ShardBuilder.create_empty(chunks_per_shard)
468+
shard_builder = dict.fromkeys(morton_order_iter(chunks_per_shard))
577469

578470
await self.codec_pipeline.write(
579471
[
@@ -589,7 +481,11 @@ async def _encode_single(
589481
shard_array,
590482
)
591483

592-
return await shard_builder.finalize(self.index_location, self._encode_shard_index)
484+
return await self._encode_shard_dict(
485+
shard_builder,
486+
chunks_per_shard=chunks_per_shard,
487+
buffer_prototype=default_buffer_prototype(),
488+
)
593489

594490
async def _encode_partial_single(
595491
self,
@@ -603,15 +499,13 @@ async def _encode_partial_single(
603499
chunks_per_shard = self._get_chunks_per_shard(shard_spec)
604500
chunk_spec = self._get_chunk_spec(shard_spec)
605501

606-
shard_dict = _MergingShardBuilder(
607-
await self._load_full_shard_maybe(
608-
byte_getter=byte_setter,
609-
prototype=chunk_spec.prototype,
610-
chunks_per_shard=chunks_per_shard,
611-
)
612-
or _ShardReader.create_empty(chunks_per_shard),
613-
_ShardBuilder.create_empty(chunks_per_shard),
502+
shard_reader = await self._load_full_shard_maybe(
503+
byte_getter=byte_setter,
504+
prototype=chunk_spec.prototype,
505+
chunks_per_shard=chunks_per_shard,
614506
)
507+
shard_reader = shard_reader or _ShardReader.create_empty(chunks_per_shard)
508+
shard_dict = {k: shard_reader.get(k) for k in morton_order_iter(chunks_per_shard)}
615509

616510
indexer = list(
617511
get_indexer(
@@ -632,16 +526,57 @@ async def _encode_partial_single(
632526
],
633527
shard_array,
634528
)
529+
buf = await self._encode_shard_dict(
530+
shard_dict,
531+
chunks_per_shard=chunks_per_shard,
532+
buffer_prototype=default_buffer_prototype(),
533+
)
635534

636-
if shard_dict.is_empty():
535+
if buf is None:
637536
await byte_setter.delete()
638537
else:
639-
await byte_setter.set(
640-
await shard_dict.finalize(
641-
self.index_location,
642-
self._encode_shard_index,
643-
)
644-
)
538+
await byte_setter.set(buf)
539+
540+
async def _encode_shard_dict(
541+
self,
542+
map: ShardMapping,
543+
chunks_per_shard: tuple[int, ...],
544+
buffer_prototype: BufferPrototype,
545+
) -> Buffer | None:
546+
index = _ShardIndex.create_empty(chunks_per_shard)
547+
548+
buffers = []
549+
550+
template = buffer_prototype.buffer.create_zero_length()
551+
chunk_start = 0
552+
for chunk_coords in morton_order_iter(chunks_per_shard):
553+
value = map.get(chunk_coords)
554+
if value is None:
555+
continue
556+
557+
if len(value) == 0:
558+
continue
559+
560+
chunk_length = len(value)
561+
buffers.append(value)
562+
index.set_chunk_slice(chunk_coords, slice(chunk_start, chunk_start + chunk_length))
563+
chunk_start += chunk_length
564+
565+
if len(buffers) == 0:
566+
return None
567+
568+
index_bytes = await self._encode_shard_index(index)
569+
if self.index_location == ShardingCodecIndexLocation.start:
570+
empty_chunks_mask = index.offsets_and_lengths[..., 0] == MAX_UINT_64
571+
index.offsets_and_lengths[~empty_chunks_mask, 0] += len(index_bytes)
572+
index_bytes = await self._encode_shard_index(
573+
index
574+
) # encode again with corrected offsets
575+
buffers.insert(0, index_bytes)
576+
else:
577+
buffers.append(index_bytes)
578+
579+
return template.combine(buffers)
645580

646581
def _is_total_shard(
647582
self, all_chunk_coords: set[tuple[int, ...]], chunks_per_shard: tuple[int, ...]

src/zarr/core/buffer/core.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
import sys
44
from abc import ABC, abstractmethod
5+
from collections.abc import Iterable
56
from typing import (
67
TYPE_CHECKING,
78
Any,
@@ -294,9 +295,13 @@ def __len__(self) -> int:
294295
return self._data.size
295296

296297
@abstractmethod
298+
def combine(self, others: Iterable[Buffer]) -> Self:
299+
"""Concatenate many buffers"""
300+
...
301+
297302
def __add__(self, other: Buffer) -> Self:
298303
"""Concatenate two buffers"""
299-
...
304+
return self.combine([other])
300305

301306
def __eq__(self, other: object) -> bool:
302307
# Another Buffer class can override this to choose a more efficient path

src/zarr/core/buffer/cpu.py

Lines changed: 7 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -107,14 +107,13 @@ def as_numpy_array(self) -> npt.NDArray[Any]:
107107
"""
108108
return np.asanyarray(self._data)
109109

110-
def __add__(self, other: core.Buffer) -> Self:
111-
"""Concatenate two buffers"""
112-
113-
other_array = other.as_array_like()
114-
assert other_array.dtype == np.dtype("B")
115-
return self.__class__(
116-
np.concatenate((np.asanyarray(self._data), np.asanyarray(other_array)))
117-
)
110+
def combine(self, others: Iterable[core.Buffer]) -> Self:
111+
data = [np.asanyarray(self._data)]
112+
for buf in others:
113+
other_array = buf.as_array_like()
114+
assert other_array.dtype == np.dtype("B")
115+
data.append(np.asanyarray(other_array))
116+
return self.__class__(np.concatenate(data))
118117

119118

120119
class NDBuffer(core.NDBuffer):

src/zarr/core/buffer/gpu.py

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -107,14 +107,15 @@ def from_bytes(cls, bytes_like: BytesLike) -> Self:
107107
def as_numpy_array(self) -> npt.NDArray[Any]:
108108
return cast("npt.NDArray[Any]", cp.asnumpy(self._data))
109109

110-
def __add__(self, other: core.Buffer) -> Self:
111-
other_array = other.as_array_like()
112-
assert other_array.dtype == np.dtype("B")
113-
gpu_other = Buffer(other_array)
114-
gpu_other_array = gpu_other.as_array_like()
115-
return self.__class__(
116-
cp.concatenate((cp.asanyarray(self._data), cp.asanyarray(gpu_other_array)))
117-
)
110+
def combine(self, others: Iterable[core.Buffer]) -> Self:
111+
data = [cp.asanyarray(self._data)]
112+
for other in others:
113+
other_array = other.as_array_like()
114+
assert other_array.dtype == np.dtype("B")
115+
gpu_other = Buffer(other_array)
116+
gpu_other_array = gpu_other.as_array_like()
117+
data.append(cp.asanyarray(gpu_other_array))
118+
return self.__class__(cp.concatenate(data))
118119

119120

120121
class NDBuffer(core.NDBuffer):

0 commit comments

Comments (0)