Implement merge_chunks using map_selection

tomwhite committed Oct 31, 2024
1 parent 2eead17 commit 597cbb7
Showing 4 changed files with 71 additions and 155 deletions.
80 changes: 13 additions & 67 deletions cubed/core/ops.py
@@ -26,12 +26,12 @@
 from cubed.spec import spec_from_config
 from cubed.storage.backend import open_backend_array
 from cubed.types import T_RegularChunks, T_Shape
-from cubed.utils import _concatenate2, array_memory, array_size, get_item
+from cubed.utils import array_memory, array_size, get_item
 from cubed.utils import numblocks as compute_numblocks
 from cubed.utils import offset_to_block_id, to_chunksize
 from cubed.vendor.dask.array.core import normalize_chunks
 from cubed.vendor.dask.array.utils import validate_axis
-from cubed.vendor.dask.blockwise import broadcast_dimensions, lol_product
+from cubed.vendor.dask.blockwise import broadcast_dimensions
 from cubed.vendor.dask.utils import has_keyword
 
 if TYPE_CHECKING:
@@ -1107,77 +1107,23 @@ def merge_chunks(x, chunks):
         )
 
     target_chunks = normalize_chunks(chunks, x.shape, dtype=x.dtype)
-    return map_direct(
-        _copy_chunk,
-        x,
-        shape=x.shape,
-        dtype=x.dtype,
-        chunks=target_chunks,
-        extra_projected_mem=0,
-        target_chunks=target_chunks,
-    )
-
-
-def _copy_chunk(e, x, target_chunks=None, block_id=None):
-    if isinstance(x.zarray, dict):
-        return {
-            k: numpy_array_to_backend_array(v[get_item(target_chunks, block_id)])
-            for k, v in x.zarray.items()
-        }
-    out = x.zarray[get_item(target_chunks, block_id)]
-    out = numpy_array_to_backend_array(out)
-    return out
-
-
-def merge_chunks_new(x, chunks):
-    # new implementation that uses general_blockwise rather than map_direct
-    target_chunksize = chunks
-    if len(target_chunksize) != x.ndim:
-        raise ValueError(
-            f"Chunks {target_chunksize} must have same number of dimensions as array ({x.ndim})"
-        )
-    if not all(c1 % c0 == 0 for c0, c1 in zip(x.chunksize, target_chunksize)):
-        raise ValueError(
-            f"Chunks {target_chunksize} must be a multiple of array's chunks {x.chunksize}"
-        )
-
-    target_chunks = normalize_chunks(chunks, x.shape, dtype=x.dtype)
-    axes = [
-        i for (i, (c0, c1)) in enumerate(zip(x.chunksize, target_chunksize)) if c0 != c1
-    ]
 
-    def key_function(out_key):
+    def selection_function(out_key):
         out_coords = out_key[1:]
+        return get_item(target_chunks, out_coords)
-
-        in_keys = []
-        for i, (c0, c1) in enumerate(zip(x.chunksize, target_chunksize)):
-            k = c1 // c0  # number of blocks to merge in axis i
-            if k == 1:
-                in_keys.append(out_coords[i])
-            else:
-                start = out_coords[i] * k
-                stop = min(start + k, x.numblocks[i])
-                in_keys.append(list(range(start, stop)))
-
-        # return a tuple with a single item that is the list of input keys to be merged
-        return (lol_product((x.name,), in_keys),)
 
-    num_input_blocks = (
-        int(np.prod([c1 // c0 for (c0, c1) in zip(x.chunksize, target_chunksize)])),
+    max_num_input_blocks = math.prod(
+        c1 // c0 for c0, c1 in zip(x.chunksize, target_chunksize)
     )
-    iterable_input_blocks = (True,)
 
-    return general_blockwise(
-        _concatenate2,
-        key_function,
+    return map_selection(
+        None,  # no function to apply after selection
+        selection_function,
         x,
-        shapes=[x.shape],
-        dtypes=[x.dtype],
-        chunkss=[target_chunks],
-        extra_projected_mem=0,
-        num_input_blocks=num_input_blocks,
-        iterable_input_blocks=iterable_input_blocks,
-        axes=axes,
+        x.shape,
+        x.dtype,
+        target_chunks,
+        max_num_input_blocks=max_num_input_blocks,
     )


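As an aside for readers following the change: a minimal sketch in plain Python (not cubed's API; get_item below is a simplified stand-in whose semantics are assumed from its use above) of what selection_function and max_num_input_blocks work out to when a (3, 2) array with (1, 2) chunks is merged into one (3, 2) chunk, the case the tests below use.

import math

# Simplified stand-in for cubed.utils.get_item (assumed semantics): the tuple
# of slices covering block `coords` of an array with the given chunks.
def get_item(chunks, coords):
    out = []
    for sizes, i in zip(chunks, coords):
        start = sum(sizes[:i])
        out.append(slice(start, start + sizes[i]))
    return tuple(out)

target_chunks = ((3,), (2,))   # normalize_chunks((3, 2), shape=(3, 2)) form
out_key = ("b", 0, 0)          # (array name, *block coordinates)
print(get_item(target_chunks, out_key[1:]))  # (slice(0, 3), slice(0, 2))

# Upper bound on source blocks read per output block, as computed above:
chunksize, target_chunksize = (1, 2), (3, 2)
print(math.prod(c1 // c0 for c0, c1 in zip(chunksize, target_chunksize)))  # 3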
5 changes: 3 additions & 2 deletions cubed/tests/test_mem_utilization.py
@@ -104,9 +104,10 @@ def test_index_multiple_axes(tmp_path, spec, executor):
 
 @pytest.mark.slow
 def test_index_step(tmp_path, spec, executor):
+    # use 400MB chunks so that intermediate after indexing has 200MB chunks
     a = cubed.random.random(
-        (10000, 10000), chunks=(5000, 5000), spec=spec
-    )  # 200MB chunks
+        (20000, 10000), chunks=(10000, 5000), spec=spec
+    )  # 400MB chunks
     b = a[::2, :]
     run_operation(tmp_path, executor, "index_step", b)
 
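The arithmetic behind the new comment, assuming float64 elements (8 bytes each):

# each (10000, 5000) float64 chunk of a:
print(10000 * 5000 * 8 / 1e6)  # 400.0 (MB)
# a[::2, :] keeps every other row, so the intermediate chunk is (5000, 5000):
print(5000 * 5000 * 8 / 1e6)   # 200.0 (MB)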
76 changes: 55 additions & 21 deletions cubed/tests/test_optimization.py
@@ -9,7 +9,7 @@
 import cubed.array_api as xp
 import cubed.random
 from cubed.backend_array_api import namespace as nxp
-from cubed.core.ops import elemwise, merge_chunks_new, partial_reduce
+from cubed.core.ops import elemwise, merge_chunks, partial_reduce
 from cubed.core.optimization import (
     fuse_all_optimize_dag,
     fuse_only_optimize_dag,
@@ -259,11 +259,19 @@ def add_placeholder_op(dag, inputs, outputs):
     add_op(dag, placeholder_func, [a.name for a in inputs], [b.name for b in outputs])
 
 
-def structurally_equivalent(dag1, dag2):
+def structurally_equivalent(dag1, dag2, remove_hidden=False):
     # compare structure, and node labels for values but not operators since they are placeholders
 
-    # draw_dag(dag1, "dag1")  # uncomment for debugging
-    # draw_dag(dag2, "dag2")  # uncomment for debugging
+    if remove_hidden:
+        dag1.remove_nodes_from(
+            list(n for n, d in dag1.nodes(data=True) if d.get("hidden", False))
+        )
+        dag2.remove_nodes_from(
+            list(n for n, d in dag2.nodes(data=True) if d.get("hidden", False))
+        )
+
+    draw_dag(dag1, "dag1")  # uncomment for debugging
+    draw_dag(dag2, "dag2")  # uncomment for debugging
 
     labelled_dag1 = nx.convert_node_labels_to_integers(dag1, label_attribute="label")
     labelled_dag2 = nx.convert_node_labels_to_integers(dag2, label_attribute="label")
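
A standalone illustration of the new remove_hidden path; networkx only, with made-up node names, but the same "hidden" node attribute the diff checks:

import networkx as nx

g = nx.DiGraph()
g.add_node("merge-op")
g.add_node("block-ids", hidden=True)  # hypothetical hidden helper node
g.add_edge("block-ids", "merge-op")

# same filtering pattern as in structurally_equivalent above
g.remove_nodes_from(list(n for n, d in g.nodes(data=True) if d.get("hidden", False)))
print(list(g.nodes))  # ['merge-op']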
@@ -282,8 +290,10 @@ def nm(node_attrs1, node_attrs2):
 def draw_dag(dag, name="dag"):
     dag = dag.copy()
     for _, d in dag.nodes(data=True):
-        if "name" in d:  # pydot already has name
-            del d["name"]
+        # remove keys or values with possibly unescaped characters
+        for k in ("name", "pipeline", "primitive_op", "stack_summaries"):
+            if k in d:
+                del d[k]
     gv = nx.drawing.nx_pydot.to_pydot(dag)
     format = "svg"
     full_filename = f"{name}.{format}"
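
A sketch of what the widened loop guards against, assuming (as the new comment says) that pydot can choke on attribute values containing unescaped characters:

import networkx as nx

dag = nx.DiGraph()
# a pipeline-like object stringifies to something pydot may fail to escape
dag.add_node("op-001", name="op-001", pipeline=object())
for _, d in dag.nodes(data=True):
    for k in ("name", "pipeline", "primitive_op", "stack_summaries"):
        if k in d:
            del d[k]
gv = nx.drawing.nx_pydot.to_pydot(dag)  # succeeds once only plain attrs remain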
@@ -810,7 +820,7 @@ def test_fuse_large_fan_in_override(spec):
 #
 def test_fuse_with_merge_chunks_unary(spec):
     a = xp.ones((3, 2), chunks=(1, 2), spec=spec)
-    b = merge_chunks_new(a, chunks=(3, 2))
+    b = merge_chunks(a, chunks=(3, 2))
     c = xp.negative(b)
 
     opt_fn = fuse_one_level(c)
@@ -822,10 +832,14 @@ def test_fuse_with_merge_chunks_unary(spec):
     add_placeholder_op(expected_fused_dag, (), (a,))
     add_placeholder_op(expected_fused_dag, (a,), (c,))
     optimized_dag = c.plan.optimize(optimize_function=opt_fn).dag
-    assert structurally_equivalent(optimized_dag, expected_fused_dag)
-    assert get_num_input_blocks(b.plan.dag, b.name) == (3,)
+
+    # merge_chunks uses a hidden op and array for block ids - ignore when comparing structure
+    assert structurally_equivalent(
+        optimized_dag, expected_fused_dag, remove_hidden=True
+    )
+    assert get_num_input_blocks(b.plan.dag, b.name) == (3, 1)  # final 1 is block ids
     assert get_num_input_blocks(c.plan.dag, c.name) == (1,)
-    assert get_num_input_blocks(optimized_dag, c.name) == (3,)
+    assert get_num_input_blocks(optimized_dag, c.name) == (3, 1)  # final 1 is block ids
 
     result = c.compute(optimize_function=opt_fn)
     assert_array_equal(result, -np.ones((3, 2)))
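
A rough reading of the tuples asserted above, as an illustration rather than cubed internals: each entry counts the blocks read per input array, and map_selection appends a hidden block-ids input.

source_chunksize = (1, 2)
target_chunksize = (3, 2)
data_blocks = 1
for c0, c1 in zip(source_chunksize, target_chunksize):
    data_blocks *= c1 // c0  # blocks merged along each axis: 3 and 1
print((data_blocks, 1))  # (3, 1): 3 data blocks plus the block-ids block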
@@ -842,7 +856,7 @@ def test_fuse_with_merge_chunks_unary(spec):
 def test_fuse_with_merge_chunks_binary(spec):
     a = xp.ones((3, 2), chunks=(1, 2), spec=spec)
     b = xp.ones((3, 2), chunks=(3, 2), spec=spec)
-    c = merge_chunks_new(a, chunks=(3, 2))
+    c = merge_chunks(a, chunks=(3, 2))
     d = xp.negative(b)
     e = xp.add(c, d)
 
@@ -856,9 +870,17 @@ def test_fuse_with_merge_chunks_binary(spec):
     add_placeholder_op(expected_fused_dag, (), (b,))
     add_placeholder_op(expected_fused_dag, (a, b), (e,))
     optimized_dag = e.plan.optimize(optimize_function=opt_fn).dag
-    assert structurally_equivalent(optimized_dag, expected_fused_dag)
+
+    # merge_chunks uses a hidden op and array for block ids - ignore when comparing structure
+    assert structurally_equivalent(
+        optimized_dag, expected_fused_dag, remove_hidden=True
+    )
     assert get_num_input_blocks(e.plan.dag, e.name) == (1, 1)
-    assert get_num_input_blocks(optimized_dag, e.name) == (3, 1)
+    assert get_num_input_blocks(optimized_dag, e.name) == (
+        3,
+        1,
+        1,
+    )  # final 1 is block ids
 
     result = e.compute(optimize_function=opt_fn)
     assert_array_equal(result, np.zeros((3, 2)))
@@ -875,7 +897,7 @@ def test_fuse_with_merge_chunks_binary(spec):
 def test_fuse_merge_chunks_unary(spec):
     a = xp.ones((3, 2), chunks=(1, 2), spec=spec)
     b = xp.negative(a)
-    c = merge_chunks_new(b, chunks=(3, 2))
+    c = merge_chunks(b, chunks=(3, 2))
 
     # specify max_total_num_input_blocks to force c to fuse
     opt_fn = fuse_multiple_levels(max_total_num_input_blocks=3)
@@ -887,10 +909,14 @@ def test_fuse_merge_chunks_unary(spec):
     add_placeholder_op(expected_fused_dag, (), (a,))
     add_placeholder_op(expected_fused_dag, (a,), (c,))
     optimized_dag = c.plan.optimize(optimize_function=opt_fn).dag
-    assert structurally_equivalent(optimized_dag, expected_fused_dag)
+
+    # merge_chunks uses a hidden op and array for block ids - ignore when comparing structure
+    assert structurally_equivalent(
+        optimized_dag, expected_fused_dag, remove_hidden=True
+    )
     assert get_num_input_blocks(b.plan.dag, b.name) == (1,)
-    assert get_num_input_blocks(c.plan.dag, c.name) == (3,)
-    assert get_num_input_blocks(optimized_dag, c.name) == (3,)
+    assert get_num_input_blocks(c.plan.dag, c.name) == (3, 1)  # final 1 is block ids
+    assert get_num_input_blocks(optimized_dag, c.name) == (3, 1)  # final 1 is block ids
 
     result = c.compute(optimize_function=opt_fn)
     assert_array_equal(result, -np.ones((3, 2)))
@@ -908,7 +934,7 @@ def test_fuse_merge_chunks_binary(spec):
     a = xp.ones((3, 2), chunks=(1, 2), spec=spec)
     b = xp.ones((3, 2), chunks=(1, 2), spec=spec)
     c = xp.add(a, b)
-    d = merge_chunks_new(c, chunks=(3, 2))
+    d = merge_chunks(c, chunks=(3, 2))
 
     # specify max_total_num_input_blocks to force d to fuse
     opt_fn = fuse_multiple_levels(max_total_num_input_blocks=6)
@@ -921,10 +947,18 @@ def test_fuse_merge_chunks_binary(spec):
     add_placeholder_op(expected_fused_dag, (), (b,))
     add_placeholder_op(expected_fused_dag, (a, b), (d,))
     optimized_dag = d.plan.optimize(optimize_function=opt_fn).dag
-    assert structurally_equivalent(optimized_dag, expected_fused_dag)
+
+    # merge_chunks uses a hidden op and array for block ids - ignore when comparing structure
+    assert structurally_equivalent(
+        optimized_dag, expected_fused_dag, remove_hidden=True
+    )
     assert get_num_input_blocks(c.plan.dag, c.name) == (1, 1)
-    assert get_num_input_blocks(d.plan.dag, d.name) == (3,)
-    assert get_num_input_blocks(optimized_dag, d.name) == (3, 3)
+    assert get_num_input_blocks(d.plan.dag, d.name) == (3, 1)  # final 1 is block ids
+    assert get_num_input_blocks(optimized_dag, d.name) == (
+        3,
+        3,
+        1,
+    )  # final 1 is block ids
 
     result = d.compute(optimize_function=opt_fn)
     assert_array_equal(result, 2 * np.ones((3, 2)))
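
The fused counts compose as one would expect; a small illustration (again not cubed internals) of how (1, 1) and 3 combine into (3, 3, 1):

add_blocks = (1, 1)    # blocks of a and b read per block of c, before fusing
merge_blocks = 3       # data blocks of c read per block of d
fused = tuple(merge_blocks * n for n in add_blocks) + (1,)  # keep block-ids entry
print(fused)  # (3, 3, 1), matching the assertion above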
65 changes: 0 additions & 65 deletions cubed/utils.py
@@ -338,71 +338,6 @@ def broadcast_trick(func):
     return inner
 
 
-# From dask.array.core, but changed to use nxp namespace
-def _concatenate2(arrays, axes=None):
-    """Recursively concatenate nested lists of arrays along axes
-
-    Each entry in axes corresponds to each level of the nested list. The
-    length of axes should correspond to the level of nesting of arrays.
-    If axes is an empty list or tuple, return arrays, or arrays[0] if
-    arrays is a list.
-
-    >>> x = np.array([[1, 2], [3, 4]])
-    >>> _concatenate2([x, x], axes=[0])
-    array([[1, 2],
-           [3, 4],
-           [1, 2],
-           [3, 4]])
-
-    >>> _concatenate2([x, x], axes=[1])
-    array([[1, 2, 1, 2],
-           [3, 4, 3, 4]])
-
-    >>> _concatenate2([[x, x], [x, x]], axes=[0, 1])
-    array([[1, 2, 1, 2],
-           [3, 4, 3, 4],
-           [1, 2, 1, 2],
-           [3, 4, 3, 4]])
-
-    Supports Iterators
-    >>> _concatenate2(iter([x, x]), axes=[1])
-    array([[1, 2, 1, 2],
-           [3, 4, 3, 4]])
-
-    Special Case
-    >>> _concatenate2([x, x], axes=())
-    array([[1, 2],
-           [3, 4]])
-    """
-    if axes is None:
-        axes = []
-
-    if axes == ():
-        if isinstance(arrays, list):
-            return arrays[0]
-        else:
-            return arrays
-
-    if isinstance(arrays, Iterator):
-        arrays = list(arrays)
-    if not isinstance(arrays, (list, tuple)):
-        return arrays
-    if len(axes) > 1:
-        arrays = [_concatenate2(a, axes=axes[1:]) for a in arrays]
-    concatenate = nxp.concat
-    if isinstance(arrays[0], dict):
-        # Handle concatenation of `dict`s, used as a replacement for structured
-        # arrays when that's not supported by the array library (e.g., CuPy).
-        keys = list(arrays[0].keys())
-        assert all(list(a.keys()) == keys for a in arrays)
-        ret = dict()
-        for k in keys:
-            ret[k] = concatenate(list(a[k] for a in arrays), axis=axes[0])
-        return ret
-    else:
-        return concatenate(arrays, axis=axes[0])
-
-
 def normalize_shape(shape: Union[int, Tuple[int, ...], None]) -> Tuple[int, ...]:
     """Normalize a `shape` argument to a tuple of ints."""
 
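With merge_chunks now reading source blocks through a single selection in map_selection rather than concatenating lists of blocks via general_blockwise, nothing imports _concatenate2 any more (see the import change in cubed/core/ops.py above), so the vendored helper can be deleted outright.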