Skip to content

Commit

Permalink
Introduce FinalizedPlan (#563)
Browse files Browse the repository at this point in the history
The idea here is to formalise the planning process - compose -> finalize (optimize, compile, housekeeping) -> compute/visualize.
  • Loading branch information
tomwhite authored Sep 2, 2024
1 parent d2ba5e1 commit fcd4d21
Show file tree
Hide file tree
Showing 5 changed files with 118 additions and 94 deletions.
2 changes: 1 addition & 1 deletion cubed/array_api/array_object.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ def _repr_html_(self):
grid=grid,
nbytes=nbytes,
cbytes=cbytes,
arrs_in_plan=f"{self.plan.num_arrays()} arrays in Plan",
arrs_in_plan=f"{self.plan._finalize().num_arrays()} arrays in Plan",
arrtype="np.ndarray",
)

Expand Down
115 changes: 66 additions & 49 deletions cubed/core/plan.py
Original file line number Diff line number Diff line change
Expand Up @@ -207,7 +207,9 @@ def _compile_blockwise(self, dag, compile_function: Decorator) -> nx.MultiDiGrap
"""Compiles functions from all blockwise ops by mutating the input dag."""
# Recommended: make a copy of the dag before calling this function.

compile_with_config = 'config' in inspect.getfullargspec(compile_function).kwonlyargs
compile_with_config = (
"config" in inspect.getfullargspec(compile_function).kwonlyargs
)

for n in dag.nodes:
node = dag.nodes[n]
Expand All @@ -219,31 +221,36 @@ def _compile_blockwise(self, dag, compile_function: Decorator) -> nx.MultiDiGrap
continue

if compile_with_config:
compiled = compile_function(node["pipeline"].config.function, config=node["pipeline"].config)
compiled = compile_function(
node["pipeline"].config.function, config=node["pipeline"].config
)
else:
compiled = compile_function(node["pipeline"].config.function)

# node is a blockwise primitive_op.
# maybe we should investigate some sort of optics library for frozen dataclasses...
new_pipeline = dataclasses.replace(
node["pipeline"],
config=dataclasses.replace(node["pipeline"].config, function=compiled)
config=dataclasses.replace(node["pipeline"].config, function=compiled),
)
node["pipeline"] = new_pipeline

return dag

@lru_cache
def _finalize_dag(
self, optimize_graph: bool = True, optimize_function=None, compile_function: Optional[Decorator] = None,
) -> nx.MultiDiGraph:
def _finalize(
self,
optimize_graph: bool = True,
optimize_function=None,
compile_function: Optional[Decorator] = None,
) -> "FinalizedPlan":
dag = self.optimize(optimize_function).dag if optimize_graph else self.dag
# create a copy since _create_lazy_zarr_arrays mutates the dag
dag = dag.copy()
if callable(compile_function):
dag = self._compile_blockwise(dag, compile_function)
dag = self._create_lazy_zarr_arrays(dag)
return nx.freeze(dag)
return FinalizedPlan(nx.freeze(dag))

def execute(
self,
Expand All @@ -256,7 +263,10 @@ def execute(
spec=None,
**kwargs,
):
dag = self._finalize_dag(optimize_graph, optimize_function, compile_function)
finalized_plan = self._finalize(
optimize_graph, optimize_function, compile_function
)
dag = finalized_plan.dag

compute_id = f"compute-{datetime.now().strftime('%Y%m%dT%H%M%S.%f')}"

Expand All @@ -275,43 +285,6 @@ def execute(
event = ComputeEndEvent(compute_id, dag)
[callback.on_compute_end(event) for callback in callbacks]

def num_tasks(self, optimize_graph=True, optimize_function=None, resume=None):
"""Return the number of tasks needed to execute this plan."""
dag = self._finalize_dag(optimize_graph, optimize_function)
tasks = 0
for _, node in visit_nodes(dag, resume=resume):
tasks += node["primitive_op"].num_tasks
return tasks

def num_arrays(self, optimize_graph: bool = True, optimize_function=None) -> int:
"""Return the number of arrays in this plan."""
dag = self._finalize_dag(optimize_graph, optimize_function)
return sum(d.get("type") == "array" for _, d in dag.nodes(data=True))

def max_projected_mem(
self, optimize_graph=True, optimize_function=None, resume=None
):
"""Return the maximum projected memory across all tasks to execute this plan."""
dag = self._finalize_dag(optimize_graph, optimize_function)
projected_mem_values = [
node["primitive_op"].projected_mem
for _, node in visit_nodes(dag, resume=resume)
]
return max(projected_mem_values) if len(projected_mem_values) > 0 else 0

def total_nbytes_written(
self, optimize_graph: bool = True, optimize_function=None
) -> int:
"""Return the total number of bytes written for all materialized arrays in this plan."""
dag = self._finalize_dag(optimize_graph, optimize_function)
nbytes = 0
for _, d in dag.nodes(data=True):
if d.get("type") == "array":
target = d["target"]
if isinstance(target, LazyZarrArray):
nbytes += target.nbytes
return nbytes

def visualize(
self,
filename="cubed",
Expand All @@ -321,7 +294,8 @@ def visualize(
optimize_function=None,
show_hidden=False,
):
dag = self._finalize_dag(optimize_graph, optimize_function)
finalized_plan = self._finalize(optimize_graph, optimize_function)
dag = finalized_plan.dag
dag = dag.copy() # make a copy since we mutate the DAG below

# remove edges from create-arrays output node to avoid cluttering the diagram
Expand All @@ -336,9 +310,9 @@ def visualize(
"rankdir": rankdir,
"label": (
# note that \l is used to left-justify each line (see https://www.graphviz.org/docs/attrs/nojustify/)
rf"num tasks: {self.num_tasks(optimize_graph, optimize_function)}\l"
rf"max projected memory: {memory_repr(self.max_projected_mem(optimize_graph, optimize_function))}\l"
rf"total nbytes written: {memory_repr(self.total_nbytes_written(optimize_graph, optimize_function))}\l"
rf"num tasks: {finalized_plan.num_tasks()}\l"
rf"max projected memory: {memory_repr(finalized_plan.max_projected_mem())}\l"
rf"total nbytes written: {memory_repr(finalized_plan.total_nbytes_written())}\l"
rf"optimized: {optimize_graph}\l"
),
"labelloc": "bottom",
Expand Down Expand Up @@ -474,6 +448,49 @@ def visualize(
return None


class FinalizedPlan:
"""A plan that is ready to be run.
Finalizing a plan involves the following steps:
1. optimization (optional)
2. adding housekeping nodes to create arrays
3. compiling functions (optional)
4. freezing the final DAG so it can't be changed
"""

def __init__(self, dag):
self.dag = dag

def max_projected_mem(self, resume=None):
"""Return the maximum projected memory across all tasks to execute this plan."""
projected_mem_values = [
node["primitive_op"].projected_mem
for _, node in visit_nodes(self.dag, resume=resume)
]
return max(projected_mem_values) if len(projected_mem_values) > 0 else 0

def num_arrays(self) -> int:
"""Return the number of arrays in this plan."""
return sum(d.get("type") == "array" for _, d in self.dag.nodes(data=True))

def num_tasks(self, resume=None):
"""Return the number of tasks needed to execute this plan."""
tasks = 0
for _, node in visit_nodes(self.dag, resume=resume):
tasks += node["primitive_op"].num_tasks
return tasks

def total_nbytes_written(self) -> int:
"""Return the total number of bytes written for all materialized arrays in this plan."""
nbytes = 0
for _, d in self.dag.nodes(data=True):
if d.get("type") == "array":
target = d["target"]
if isinstance(target, LazyZarrArray):
nbytes += target.nbytes
return nbytes


def arrays_to_dag(*arrays):
from .array import check_array_specs

Expand Down
9 changes: 5 additions & 4 deletions cubed/tests/test_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -373,13 +373,14 @@ def test_reduction_multiple_rounds(tmp_path, executor):
a = xp.ones((100, 10), dtype=np.uint8, chunks=(1, 10), spec=spec)
b = xp.sum(a, axis=0, dtype=np.uint8)
# check that there is > 1 blockwise step (after optimization)
finalized_plan = b.plan._finalize()
blockwises = [
n
for (n, d) in b.plan.dag.nodes(data=True)
for (n, d) in finalized_plan.dag.nodes(data=True)
if d.get("op_name", None) == "blockwise"
]
assert len(blockwises) > 1
assert b.plan.max_projected_mem() <= 1000
assert finalized_plan.max_projected_mem() <= 1000
assert_array_equal(b.compute(executor=executor), np.ones((100, 10)).sum(axis=0))


Expand Down Expand Up @@ -555,7 +556,7 @@ def test_plan_scaling(tmp_path, factor):
)
c = xp.matmul(a, b)

assert c.plan.num_tasks() > 0
assert c.plan._finalize().num_tasks() > 0
c.visualize(filename=tmp_path / "c")


Expand All @@ -568,7 +569,7 @@ def test_plan_quad_means(tmp_path, t_length):
uv = u * v
m = xp.mean(uv, axis=0, split_every=10, use_new_impl=True)

assert m.plan.num_tasks() > 0
assert m.plan._finalize().num_tasks() > 0
m.visualize(
filename=tmp_path / "quad_means_unoptimized",
optimize_graph=False,
Expand Down
17 changes: 11 additions & 6 deletions cubed/tests/test_executor_features.py
Original file line number Diff line number Diff line change
Expand Up @@ -181,7 +181,7 @@ def test_resume(spec, executor):
d = xp.negative(c)

num_created_arrays = 2 # c, d
assert d.plan.num_tasks(optimize_graph=False) == num_created_arrays + 8
assert d.plan._finalize(optimize_graph=False).num_tasks() == num_created_arrays + 8

task_counter = TaskCounter()
c.compute(executor=executor, callbacks=[task_counter], optimize_graph=False)
Expand Down Expand Up @@ -321,13 +321,15 @@ def test_check_runtime_memory_processes(spec, executor):

try:
from numba import jit as numba_jit

COMPILE_FUNCTIONS.append(numba_jit)
except ModuleNotFoundError:
pass

try:
if 'jax' in os.environ.get('CUBED_BACKEND_ARRAY_API_MODULE', ''):
if "jax" in os.environ.get("CUBED_BACKEND_ARRAY_API_MODULE", ""):
from jax import jit as jax_jit

COMPILE_FUNCTIONS.append(jax_jit)
except ModuleNotFoundError:
pass
Expand All @@ -339,7 +341,8 @@ def test_check_compilation(spec, executor, compile_function):
b = xp.asarray([[1, 1, 1], [1, 1, 1], [1, 1, 1]], chunks=(2, 2), spec=spec)
c = xp.add(a, b)
assert_array_equal(
c.compute(executor=executor, compile_function=compile_function), np.array([[2, 3, 4], [5, 6, 7], [8, 9, 10]])
c.compute(executor=executor, compile_function=compile_function),
np.array([[2, 3, 4], [5, 6, 7], [8, 9, 10]]),
)


Expand All @@ -352,7 +355,7 @@ def compile_function(func):
c = xp.add(a, b)
with pytest.raises(NotImplementedError) as excinfo:
c.compute(executor=executor, compile_function=compile_function)

assert "add" in str(excinfo.value), "Compile function was applied to add operation."


Expand All @@ -365,5 +368,7 @@ def compile_function(func, *, config=None):
c = xp.add(a, b)
with pytest.raises(NotImplementedError) as excinfo:
c.compute(executor=executor, compile_function=compile_function)

assert "BlockwiseSpec" in str(excinfo.value), "Compile function was applied with a config argument."

assert "BlockwiseSpec" in str(
excinfo.value
), "Compile function was applied with a config argument."
Loading

0 comments on commit fcd4d21

Please sign in to comment.