Skip to content

Commit

Permalink
fixes per review
Browse files Browse the repository at this point in the history
  • Loading branch information
shoyer committed Nov 22, 2022
1 parent 63568f0 commit f55e0b6
Show file tree
Hide file tree
Showing 2 changed files with 47 additions and 26 deletions.
66 changes: 42 additions & 24 deletions rechunker/algorithm.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from math import ceil, floor, prod
from typing import List, Optional, Sequence, Tuple

import numpy as np
from rechunker.compat import lcm

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -113,18 +114,17 @@ def calculate_stage_chunks(
read_chunks: Tuple[int, ...],
write_chunks: Tuple[int, ...],
stage_count: int = 1,
epsilon: float = 1e-8,
) -> List[Tuple[int, ...]]:
"""
Calculate chunks after each stage of a multi-stage rechunking.
Each stage consists of "split" step followed by a "consolidate" step.
The strategy used here is to progressively enlarge or shrink chunks along
each dimension by the same multiple in each stage. This should roughly
minimize the total number of arrays resulting from "split" steps in a
multi-stage pipeline. It also keeps the total number of elements in each
chunk constant, up to rounding error, so memory usage should remain
each dimension by the same multiple in each stage (geometric spacing). This
should roughly minimize the total number of arrays resulting from "split"
steps in a multi-stage pipeline. It also keeps the total number of elements
in each chunk constant, up to rounding error, so memory usage should remain
constant.
Examples::
Expand All @@ -136,22 +136,35 @@ def calculate_stage_chunks(
>>> calculate_stage_chunks((1_000_000, 1), (1, 1_000_000), stage_count=4)
[(31623, 32), (1000, 1000), (32, 31623)]
TODO: consider more sophisticated algorithms.
TODO: consider more sophisticated algorithms. In particular, exact geometric
spacing often requires irregular intermediate chunk sizes, which (currently)
cannot be stored in Zarr arrays.
"""
stage_chunks = []
for stage in range(1, stage_count):
power = stage / stage_count
# Add a small floating-point epsilon so we don't inadvertently
# round-down even chunk-sizes.
chunks = tuple(
floor(rc ** (1 - power) * wc**power + epsilon)
for rc, wc in zip(read_chunks, write_chunks)
)
stage_chunks.append(chunks)
return stage_chunks
approx_stages = np.geomspace(read_chunks, write_chunks, num=stage_count + 1)
return [tuple(floor(c) for c in stage) for stage in approx_stages[1:-1]]


def _count_intermediate_chunks(source_chunk: int, target_chunk: int, size: int) -> int:
"""Count intermediate chunks required for rechunking along a dimension.
Intermediate chunks must divide both the source and target chunks, and in
general do not need to have a regular size. The number of intermediate
chunks is proportional to the number of required read/write operations.
For example, suppose we want to rechunk an array of size 20 from size 5
chunks to size 7 chunks. We can draw out how the array elements are divided:
0 1 2 3 4|5 6 7 8 9|10 11 12 13 14|15 16 17 18 19 (4 chunks)
0 1 2 3 4 5 6|7 8 9 10 11 12 13|14 15 16 17 18 19 (3 chunks)
To transfer these chunks, we would need to divide the array into irregular
intermediate chunks that fit into both the source and target:
0 1 2 3 4|5 6|7 8 9|10 11 12 13|14|15 16 17 18 19 (6 chunks)
This matches what ``_count_intermediate_chunks()`` calculates::
def _count_num_splits(source_chunk: int, target_chunk: int, size: int) -> int:
>>> _count_intermediate_chunks(5, 7, 20)
6
"""
multiple = lcm(source_chunk, target_chunk)
splits_per_lcm = multiple // source_chunk + multiple // target_chunk - 1
lcm_count, remainder = divmod(size, multiple)
Expand All @@ -167,8 +180,8 @@ def _count_num_splits(source_chunk: int, target_chunk: int, size: int) -> int:
def calculate_single_stage_io_ops(
shape: Sequence[int], in_chunks: Sequence[int], out_chunks: Sequence[int]
) -> int:
"""Estimate the number of irregular chunks required for rechunking."""
return prod(map(_count_num_splits, in_chunks, out_chunks, shape))
"""Count the number of read/write operations required for rechunking."""
return prod(map(_count_intermediate_chunks, in_chunks, out_chunks, shape))


# not a tight upper bound, but ensures that the loop in
Expand All @@ -193,7 +206,12 @@ def multistage_rechunking_plan(
consolidate_reads: bool = True,
consolidate_writes: bool = True,
) -> _MultistagePlan:
"""A rechunking plan that can use multiple split/consolidate steps."""
"""Caculate a rechunking plan that can use multiple split/consolidate steps.
For best results, max_mem should be significantly larger than min_mem (e.g.,
10x). Otherwise an excessive number of rechunking steps will be required.
"""

ndim = len(shape)
if len(source_chunks) != ndim:
raise ValueError(f"source_chunks {source_chunks} must have length {ndim}")
Expand All @@ -212,9 +230,9 @@ def multistage_rechunking_plan(
f"Target chunk memory ({target_chunk_mem}) exceeds max_mem ({max_mem})"
)

if max_mem < min_mem: # basic sanity test
if max_mem < min_mem: # basic sanity check
raise ValueError(
"max_mem ({max_mem}) cannot be smaller than min_mem ({min_mem})"
f"max_mem ({max_mem}) cannot be smaller than min_mem ({min_mem})"
)

if consolidate_writes:
Expand Down Expand Up @@ -321,7 +339,7 @@ def rechunking_plan(
Target chunk shape (must be in form (5, 10, 20), no irregular chunks)
itemsize: int
Number of bytes used to represent a single array element
max_mem : Int
max_mem : int
Maximum permissible chunk memory size, measured in units of itemsize
consolidate_reads: bool, optional
Whether to apply read chunk consolidation
Expand Down
7 changes: 5 additions & 2 deletions tests/test_algorithm.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,16 +125,19 @@ def test_calculate_stage_chunks(read_chunks, write_chunks, stage_count, expected
@pytest.mark.parametrize(
"shape, in_chunks, out_chunks, expected",
[
# simple 1d cases
((6,), (1,), (6,), 6),
((10,), (1,), (6,), 10),
((6,), (2,), (3,), 4),
((24,), (2,), (3,), 16),
((10,), (4,), (5,), 4),
((100,), (4,), (5,), 40),
# simple 2d cases
((100, 100), (1, 100), (100, 1), 10_000),
((100, 100), (1, 10), (10, 1), 10_000),
((100, 100), (20, 20), (25, 25), (5 + 3) ** 2),
((50, 50), (20, 20), (25, 25), ((5 + 3) // 2) ** 2),
((100, 100), (20, 20), (25, 25), 8**2),
((50, 50), (20, 20), (25, 25), 4**2),
# edge cases where one chunk size is 43 (a prime)
((100,), (43,), (100,), 3),
((100,), (43,), (51,), 4),
((100,), (43,), (40,), 5),
Expand Down

0 comments on commit f55e0b6

Please sign in to comment.