diff --git a/cubed/core/ops.py b/cubed/core/ops.py
index 39121fb2..8600bcd3 100644
--- a/cubed/core/ops.py
+++ b/cubed/core/ops.py
@@ -26,12 +26,12 @@ from cubed.spec import spec_from_config
 from cubed.storage.backend import open_backend_array
 from cubed.types import T_RegularChunks, T_Shape
-from cubed.utils import _concatenate2, array_memory, array_size, get_item
+from cubed.utils import array_memory, array_size, get_item
 from cubed.utils import numblocks as compute_numblocks
 from cubed.utils import offset_to_block_id, to_chunksize
 from cubed.vendor.dask.array.core import normalize_chunks
 from cubed.vendor.dask.array.utils import validate_axis
-from cubed.vendor.dask.blockwise import broadcast_dimensions, lol_product
+from cubed.vendor.dask.blockwise import broadcast_dimensions
 from cubed.vendor.dask.utils import has_keyword
 
 if TYPE_CHECKING:
@@ -1107,77 +1107,23 @@ def merge_chunks(x, chunks):
         )
 
     target_chunks = normalize_chunks(chunks, x.shape, dtype=x.dtype)
-    return map_direct(
-        _copy_chunk,
-        x,
-        shape=x.shape,
-        dtype=x.dtype,
-        chunks=target_chunks,
-        extra_projected_mem=0,
-        target_chunks=target_chunks,
-    )
-
-
-def _copy_chunk(e, x, target_chunks=None, block_id=None):
-    if isinstance(x.zarray, dict):
-        return {
-            k: numpy_array_to_backend_array(v[get_item(target_chunks, block_id)])
-            for k, v in x.zarray.items()
-        }
-    out = x.zarray[get_item(target_chunks, block_id)]
-    out = numpy_array_to_backend_array(out)
-    return out
-
-
-def merge_chunks_new(x, chunks):
-    # new implementation that uses general_blockwise rather than map_direct
-    target_chunksize = chunks
-    if len(target_chunksize) != x.ndim:
-        raise ValueError(
-            f"Chunks {target_chunksize} must have same number of dimensions as array ({x.ndim})"
-        )
-    if not all(c1 % c0 == 0 for c0, c1 in zip(x.chunksize, target_chunksize)):
-        raise ValueError(
-            f"Chunks {target_chunksize} must be a multiple of array's chunks {x.chunksize}"
-        )
-
-    target_chunks = normalize_chunks(chunks, x.shape, dtype=x.dtype)
-    axes = [
-        i for (i, (c0, c1)) in enumerate(zip(x.chunksize, target_chunksize)) if c0 != c1
-    ]
 
-    def key_function(out_key):
+    def selection_function(out_key):
         out_coords = out_key[1:]
+        return get_item(target_chunks, out_coords)
 
-        in_keys = []
-        for i, (c0, c1) in enumerate(zip(x.chunksize, target_chunksize)):
-            k = c1 // c0  # number of blocks to merge in axis i
-            if k == 1:
-                in_keys.append(out_coords[i])
-            else:
-                start = out_coords[i] * k
-                stop = min(start + k, x.numblocks[i])
-                in_keys.append(list(range(start, stop)))
-
-        # return a tuple with a single item that is the list of input keys to be merged
-        return (lol_product((x.name,), in_keys),)
-
-    num_input_blocks = (
-        int(np.prod([c1 // c0 for (c0, c1) in zip(x.chunksize, target_chunksize)])),
+    max_num_input_blocks = math.prod(
+        c1 // c0 for c0, c1 in zip(x.chunksize, target_chunksize)
     )
-    iterable_input_blocks = (True,)
 
-    return general_blockwise(
-        _concatenate2,
-        key_function,
+    return map_selection(
+        None,  # no function to apply after selection
+        selection_function,
         x,
-        shapes=[x.shape],
-        dtypes=[x.dtype],
-        chunkss=[target_chunks],
-        extra_projected_mem=0,
-        num_input_blocks=num_input_blocks,
-        iterable_input_blocks=iterable_input_blocks,
-        axes=axes,
+        x.shape,
+        x.dtype,
+        target_chunks,
+        max_num_input_blocks=max_num_input_blocks,
     )
diff --git a/cubed/tests/test_mem_utilization.py b/cubed/tests/test_mem_utilization.py
index a384b90f..be02973c 100644
--- a/cubed/tests/test_mem_utilization.py
+++ b/cubed/tests/test_mem_utilization.py
@@ -104,9 +104,10 @@ def test_index_multiple_axes(tmp_path, spec, executor):
 
 @pytest.mark.slow
 def test_index_step(tmp_path, spec, executor):
+    # use 400MB chunks so that the intermediate array after indexing has 200MB chunks
     a = cubed.random.random(
-        (10000, 10000), chunks=(5000, 5000), spec=spec
-    )  # 200MB chunks
+        (20000, 10000), chunks=(10000, 5000), spec=spec
+    )  # 400MB chunks
     b = a[::2, :]
 
     run_operation(tmp_path, executor, "index_step", b)
diff --git a/cubed/tests/test_optimization.py b/cubed/tests/test_optimization.py
index d3feae77..32101024 100644
--- a/cubed/tests/test_optimization.py
+++ b/cubed/tests/test_optimization.py
@@ -9,7 +9,7 @@ import cubed.array_api as xp
 import cubed.random
 from cubed.backend_array_api import namespace as nxp
-from cubed.core.ops import elemwise, merge_chunks_new, partial_reduce
+from cubed.core.ops import elemwise, merge_chunks, partial_reduce
 from cubed.core.optimization import (
     fuse_all_optimize_dag,
     fuse_only_optimize_dag,
@@ -259,11 +259,19 @@ def add_placeholder_op(dag, inputs, outputs):
     add_op(dag, placeholder_func, [a.name for a in inputs], [b.name for b in outputs])
 
 
-def structurally_equivalent(dag1, dag2):
+def structurally_equivalent(dag1, dag2, remove_hidden=False):
     # compare structure, and node labels for values but not operators since they are placeholders
-    # draw_dag(dag1, "dag1")  # uncomment for debugging
-    # draw_dag(dag2, "dag2")  # uncomment for debugging
+    if remove_hidden:
+        dag1.remove_nodes_from(
+            list(n for n, d in dag1.nodes(data=True) if d.get("hidden", False))
+        )
+        dag2.remove_nodes_from(
+            list(n for n, d in dag2.nodes(data=True) if d.get("hidden", False))
+        )
+
+    # draw_dag(dag1, "dag1")  # uncomment for debugging
+    # draw_dag(dag2, "dag2")  # uncomment for debugging
 
     labelled_dag1 = nx.convert_node_labels_to_integers(dag1, label_attribute="label")
     labelled_dag2 = nx.convert_node_labels_to_integers(dag2, label_attribute="label")
 
@@ -282,8 +290,10 @@ def nm(node_attrs1, node_attrs2):
 def draw_dag(dag, name="dag"):
     dag = dag.copy()
     for _, d in dag.nodes(data=True):
-        if "name" in d:  # pydot already has name
-            del d["name"]
+        # remove attributes whose values may contain unescaped characters
+        for k in ("name", "pipeline", "primitive_op", "stack_summaries"):
+            if k in d:
+                del d[k]
     gv = nx.drawing.nx_pydot.to_pydot(dag)
     format = "svg"
     full_filename = f"{name}.{format}"
@@ -810,7 +820,7 @@ def test_fuse_large_fan_in_override(spec):
 #
 def test_fuse_with_merge_chunks_unary(spec):
     a = xp.ones((3, 2), chunks=(1, 2), spec=spec)
-    b = merge_chunks_new(a, chunks=(3, 2))
+    b = merge_chunks(a, chunks=(3, 2))
     c = xp.negative(b)
 
     opt_fn = fuse_one_level(c)
@@ -822,10 +832,14 @@ def test_fuse_with_merge_chunks_unary(spec):
     add_placeholder_op(expected_fused_dag, (), (a,))
     add_placeholder_op(expected_fused_dag, (a,), (c,))
     optimized_dag = c.plan.optimize(optimize_function=opt_fn).dag
-    assert structurally_equivalent(optimized_dag, expected_fused_dag)
-    assert get_num_input_blocks(b.plan.dag, b.name) == (3,)
+
+    # merge_chunks uses a hidden op and array for block ids - ignore when comparing structure
+    assert structurally_equivalent(
+        optimized_dag, expected_fused_dag, remove_hidden=True
+    )
+    assert get_num_input_blocks(b.plan.dag, b.name) == (3, 1)  # final 1 is block ids
     assert get_num_input_blocks(c.plan.dag, c.name) == (1,)
-    assert get_num_input_blocks(optimized_dag, c.name) == (3,)
+    assert get_num_input_blocks(optimized_dag, c.name) == (3, 1)  # final 1 is block ids
 
     result = c.compute(optimize_function=opt_fn)
     assert_array_equal(result, -np.ones((3, 2)))
@@ -842,7 +856,7 @@ def test_fuse_with_merge_chunks_binary(spec):
 def test_fuse_with_merge_chunks_binary(spec):
     a = xp.ones((3, 2), chunks=(1, 2), spec=spec)
     b = xp.ones((3, 2), chunks=(3, 2), spec=spec)
-    c = merge_chunks_new(a, chunks=(3, 2))
+    c = merge_chunks(a, chunks=(3, 2))
     d = xp.negative(b)
     e = xp.add(c, d)
 
@@ -856,9 +870,17 @@ def test_fuse_with_merge_chunks_binary(spec):
     add_placeholder_op(expected_fused_dag, (), (b,))
     add_placeholder_op(expected_fused_dag, (a, b), (e,))
     optimized_dag = e.plan.optimize(optimize_function=opt_fn).dag
-    assert structurally_equivalent(optimized_dag, expected_fused_dag)
+
+    # merge_chunks uses a hidden op and array for block ids - ignore when comparing structure
+    assert structurally_equivalent(
+        optimized_dag, expected_fused_dag, remove_hidden=True
+    )
     assert get_num_input_blocks(e.plan.dag, e.name) == (1, 1)
-    assert get_num_input_blocks(optimized_dag, e.name) == (3, 1)
+    assert get_num_input_blocks(optimized_dag, e.name) == (
+        3,
+        1,
+        1,
+    )  # final 1 is block ids
 
     result = e.compute(optimize_function=opt_fn)
     assert_array_equal(result, np.zeros((3, 2)))
@@ -875,7 +897,7 @@ def test_fuse_merge_chunks_unary(spec):
     a = xp.ones((3, 2), chunks=(1, 2), spec=spec)
     b = xp.negative(a)
-    c = merge_chunks_new(b, chunks=(3, 2))
+    c = merge_chunks(b, chunks=(3, 2))
 
     # specify max_total_num_input_blocks to force c to fuse
     opt_fn = fuse_multiple_levels(max_total_num_input_blocks=3)
 
@@ -887,10 +909,14 @@ def test_fuse_merge_chunks_unary(spec):
     add_placeholder_op(expected_fused_dag, (), (a,))
     add_placeholder_op(expected_fused_dag, (a,), (c,))
     optimized_dag = c.plan.optimize(optimize_function=opt_fn).dag
-    assert structurally_equivalent(optimized_dag, expected_fused_dag)
+
+    # merge_chunks uses a hidden op and array for block ids - ignore when comparing structure
+    assert structurally_equivalent(
+        optimized_dag, expected_fused_dag, remove_hidden=True
+    )
     assert get_num_input_blocks(b.plan.dag, b.name) == (1,)
-    assert get_num_input_blocks(c.plan.dag, c.name) == (3,)
-    assert get_num_input_blocks(optimized_dag, c.name) == (3,)
+    assert get_num_input_blocks(c.plan.dag, c.name) == (3, 1)  # final 1 is block ids
+    assert get_num_input_blocks(optimized_dag, c.name) == (3, 1)  # final 1 is block ids
 
     result = c.compute(optimize_function=opt_fn)
     assert_array_equal(result, -np.ones((3, 2)))
@@ -908,7 +934,7 @@ def test_fuse_merge_chunks_binary(spec):
     a = xp.ones((3, 2), chunks=(1, 2), spec=spec)
     b = xp.ones((3, 2), chunks=(1, 2), spec=spec)
     c = xp.add(a, b)
-    d = merge_chunks_new(c, chunks=(3, 2))
+    d = merge_chunks(c, chunks=(3, 2))
 
     # specify max_total_num_input_blocks to force d to fuse
     opt_fn = fuse_multiple_levels(max_total_num_input_blocks=6)
@@ -921,10 +947,18 @@ def test_fuse_merge_chunks_binary(spec):
     add_placeholder_op(expected_fused_dag, (), (b,))
     add_placeholder_op(expected_fused_dag, (a, b), (d,))
     optimized_dag = d.plan.optimize(optimize_function=opt_fn).dag
-    assert structurally_equivalent(optimized_dag, expected_fused_dag)
+
+    # merge_chunks uses a hidden op and array for block ids - ignore when comparing structure
+    assert structurally_equivalent(
+        optimized_dag, expected_fused_dag, remove_hidden=True
+    )
     assert get_num_input_blocks(c.plan.dag, c.name) == (1, 1)
-    assert get_num_input_blocks(d.plan.dag, d.name) == (3,)
-    assert get_num_input_blocks(optimized_dag, d.name) == (3, 3)
+    assert get_num_input_blocks(d.plan.dag, d.name) == (3, 1)  # final 1 is block ids
+    assert get_num_input_blocks(optimized_dag, d.name) == (
+        3,
+        3,
+        1,
+    )  # final 1 is block ids
 
     result = d.compute(optimize_function=opt_fn)
     assert_array_equal(result, 2 * np.ones((3, 2)))
diff --git a/cubed/utils.py b/cubed/utils.py
index e5837134..3263800b 100644
--- a/cubed/utils.py
+++ b/cubed/utils.py
@@ -338,71 +338,6 @@ def broadcast_trick(func):
     return inner
 
 
-# From dask.array.core, but changed to use nxp namespace
-def _concatenate2(arrays, axes=None):
-    """Recursively concatenate nested lists of arrays along axes
-
-    Each entry in axes corresponds to each level of the nested list. The
-    length of axes should correspond to the level of nesting of arrays.
-    If axes is an empty list or tuple, return arrays, or arrays[0] if
-    arrays is a list.
-
-    >>> x = np.array([[1, 2], [3, 4]])
-    >>> _concatenate2([x, x], axes=[0])
-    array([[1, 2],
-           [3, 4],
-           [1, 2],
-           [3, 4]])
-
-    >>> _concatenate2([x, x], axes=[1])
-    array([[1, 2, 1, 2],
-           [3, 4, 3, 4]])
-
-    >>> _concatenate2([[x, x], [x, x]], axes=[0, 1])
-    array([[1, 2, 1, 2],
-           [3, 4, 3, 4],
-           [1, 2, 1, 2],
-           [3, 4, 3, 4]])
-
-    Supports Iterators
-    >>> _concatenate2(iter([x, x]), axes=[1])
-    array([[1, 2, 1, 2],
-           [3, 4, 3, 4]])
-
-    Special Case
-    >>> _concatenate2([x, x], axes=())
-    array([[1, 2],
-           [3, 4]])
-    """
-    if axes is None:
-        axes = []
-
-    if axes == ():
-        if isinstance(arrays, list):
-            return arrays[0]
-        else:
-            return arrays
-
-    if isinstance(arrays, Iterator):
-        arrays = list(arrays)
-    if not isinstance(arrays, (list, tuple)):
-        return arrays
-    if len(axes) > 1:
-        arrays = [_concatenate2(a, axes=axes[1:]) for a in arrays]
-    concatenate = nxp.concat
-    if isinstance(arrays[0], dict):
-        # Handle concatenation of `dict`s, used as a replacement for structured
-        # arrays when that's not supported by the array library (e.g., CuPy).
-        keys = list(arrays[0].keys())
-        assert all(list(a.keys()) == keys for a in arrays)
-        ret = dict()
-        for k in keys:
-            ret[k] = concatenate(list(a[k] for a in arrays), axis=axes[0])
-        return ret
-    else:
-        return concatenate(arrays, axis=axes[0])
-
-
 def normalize_shape(shape: Union[int, Tuple[int, ...], None]) -> Tuple[int, ...]:
     """Normalize a `shape` argument to a tuple of ints."""
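
For reference, a minimal usage sketch of the rewritten merge_chunks (not part of the patch). It mirrors the shapes used in the unit tests above, and assumes the default local spec, so no spec argument is passed:

    import numpy as np
    from numpy.testing import assert_array_equal

    import cubed.array_api as xp
    from cubed.core.ops import merge_chunks

    # merge three (1, 2) blocks into a single (3, 2) block; the target
    # chunks must be a multiple of the array's chunks in every axis
    a = xp.ones((3, 2), chunks=(1, 2))
    b = merge_chunks(a, chunks=(3, 2))
    assert_array_equal(b.compute(), np.ones((3, 2)))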