Skip to content

Commit

Permalink
Defer to merge_chunks in special cases of rechunk (#612)
Browse files Browse the repository at this point in the history
* Defer to merge_chunks in special cases of rechunk

* Add top-level rechunk function

* Fix rechunk tests
  • Loading branch information
tomwhite authored Nov 15, 2024
1 parent 55c632f commit c75556f
Show file tree
Hide file tree
Showing 4 changed files with 34 additions and 1 deletion.
3 changes: 2 additions & 1 deletion cubed/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@

from .core.array import compute, measure_reserved_mem, visualize
from .core.gufunc import apply_gufunc
from .core.ops import from_array, from_zarr, map_blocks, store, to_zarr
from .core.ops import from_array, from_zarr, map_blocks, rechunk, store, to_zarr
from .nan_functions import nanmean, nansum
from .overlap import map_overlap
from .pad import pad
Expand All @@ -38,6 +38,7 @@
"nanmean",
"nansum",
"pad",
"rechunk",
"store",
"to_zarr",
"visualize",
Expand Down
5 changes: 5 additions & 0 deletions cubed/core/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -1024,6 +1024,11 @@ def rechunk(x, chunks, target_store=None):
return x
# normalizing takes care of dict args for chunks
target_chunks = to_chunksize(normalized_chunks)

# merge chunks special case
if all(c1 % c0 == 0 for c0, c1 in zip(x.chunksize, target_chunks)):
return merge_chunks(x, target_chunks)

name = gensym()
spec = x.spec
if target_store is None:
Expand Down
26 changes: 26 additions & 0 deletions cubed/tests/test_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -298,6 +298,7 @@ def test_rechunk(spec, executor, new_chunks, expected_chunks):
def test_rechunk_same_chunks(spec):
a = xp.asarray([[1, 2, 3], [4, 5, 6], [7, 8, 9]], chunks=(2, 1), spec=spec)
b = a.rechunk((2, 1))
assert b is a
task_counter = TaskCounter()
res = b.compute(callbacks=[task_counter])
# no tasks should have run since chunks are same
Expand All @@ -314,6 +315,31 @@ def test_rechunk_intermediate(tmp_path):
assert_array_equal(b.compute(), np.ones((4, 4)))
intermediates = [n for (n, d) in b.plan.dag.nodes(data=True) if "-int" in d["name"]]
assert len(intermediates) == 1
rechunks = [
n
for (n, d) in b.plan.dag.nodes(data=True)
if d.get("op_name", None) == "rechunk"
]
assert len(rechunks) > 0


def test_rechunk_merge_chunks_optimization():
    """Rechunk to exact multiples of the source chunks and inspect the plan.

    The source array uses chunks of (2, 1) and is rechunked to (4, 2), so
    every target chunk size is a whole multiple of the corresponding source
    chunk size. The result must have the requested chunking, hold the same
    values, and its plan must contain no node whose ``op_name`` is
    ``"rechunk"``.
    """
    values = [[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12], [13, 14, 15, 16]]
    a = xp.asarray(values, chunks=(2, 1))

    b = a.rechunk((4, 2))

    # Chunking matches the request and the data is unchanged.
    assert b.chunks == ((4,), (2, 2))
    assert_array_equal(b.compute(), np.array(values))

    # Collect every plan node tagged as a rechunk operation; none are expected.
    rechunk_nodes = []
    for node, attrs in b.plan.dag.nodes(data=True):
        if attrs.get("op_name") == "rechunk":
            rechunk_nodes.append(node)
    assert not rechunk_nodes


def test_compute_is_idempotent(spec, executor):
Expand Down
1 change: 1 addition & 0 deletions docs/api.rst
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ Chunk-specific functions
apply_gufunc
map_blocks
map_overlap
rechunk

Non-standardised functions
==========================
Expand Down

0 comments on commit c75556f

Please sign in to comment.