Skip to content

Commit

Permalink
Fail during planning if map_blocks drop_axis is for a chunked dimensi…
Browse files Browse the repository at this point in the history
…on (#569)
  • Loading branch information
tomwhite authored Sep 9, 2024
1 parent ecfb10f commit 7b32304
Show file tree
Hide file tree
Showing 3 changed files with 35 additions and 4 deletions.
9 changes: 9 additions & 0 deletions cubed/primitive/blockwise.py
Original file line number Diff line number Diff line change
Expand Up @@ -687,6 +687,15 @@ def make_blockwise_key_function(
False,
)

for axes, (arg, _) in zip(concat_axes, argpairs):
for ax in axes:
if numblocks[arg][ax] > 1:
raise ValueError(
f"Cannot have multiple chunks in dropped axis {ax}. "
"To fix, use a reduction after calling map_blocks "
"without specifying drop_axis, or rechunk first."
)

def key_function(out_key):
out_coords = out_key[1:]

Expand Down
8 changes: 4 additions & 4 deletions cubed/tests/primitive/test_blockwise.py
Original file line number Diff line number Diff line change
Expand Up @@ -266,11 +266,11 @@ def test_make_blockwise_key_function_contract():
func = lambda x: 0

key_fn = make_blockwise_key_function(
func, "z", "ik", "x", "ij", "y", "jk", numblocks={"x": (2, 2), "y": (2, 2)}
func, "z", "ik", "x", "ij", "y", "jk", numblocks={"x": (2, 1), "y": (1, 2)}
)

graph = make_blockwise_graph(
func, "z", "ik", "x", "ij", "y", "jk", numblocks={"x": (2, 2), "y": (2, 2)}
func, "z", "ik", "x", "ij", "y", "jk", numblocks={"x": (2, 1), "y": (1, 2)}
)
check_consistent_with_graph(key_fn, graph)

Expand All @@ -290,10 +290,10 @@ def test_make_blockwise_key_function_contract_0d():
func = lambda x: 0

key_fn = make_blockwise_key_function(
func, "z", "", "x", "ij", numblocks={"x": (2, 2)}
func, "z", "", "x", "ij", numblocks={"x": (1, 1)}
)

graph = make_blockwise_graph(func, "z", "", "x", "ij", numblocks={"x": (2, 2)})
graph = make_blockwise_graph(func, "z", "", "x", "ij", numblocks={"x": (1, 1)})
check_consistent_with_graph(key_fn, graph)


Expand Down
22 changes: 22 additions & 0 deletions cubed/tests/test_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -235,6 +235,28 @@ def func(x, y):
assert_array_equal(c.compute(), np.array([[[12, 13]]]))


def test_map_blocks_drop_axis_chunking(spec):
# This tests the case illustrated in https://docs.dask.org/en/stable/generated/dask.array.map_blocks.html
# Unlike Dask, Cubed does not support concatenating chunks, and will fail if the dropped axis has multiple chunks.

def func(x):
return nxp.sum(x, axis=2)

an = np.arange(8 * 6 * 2).reshape((8, 6, 2))

# single chunk in axis=2 works fine
a = xp.asarray(an, chunks=(5, 4, 2), spec=spec)
b = cubed.map_blocks(func, a, drop_axis=2)
assert_array_equal(b.compute(), np.sum(an, axis=2))

# multiple chunks in axis=2 raises
a = xp.asarray(an, chunks=(5, 4, 1), spec=spec)
with pytest.raises(
ValueError, match=r"Cannot have multiple chunks in dropped axis 2."
):
cubed.map_blocks(func, a, drop_axis=2)


def test_map_blocks_with_non_cubed_array(spec):
a = xp.arange(10, dtype="int64", chunks=(2,), spec=spec)
b = np.array([1, 2], dtype="int64") # numpy array will be coerced to cubed
Expand Down

0 comments on commit 7b32304

Please sign in to comment.