Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve resume suggestions #3719

Merged
merged 13 commits into from
Apr 2, 2024
178 changes: 129 additions & 49 deletions kedro/runner/runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
as_completed,
wait,
)
from typing import Any, Iterable, Iterator
from typing import Any, Collection, Iterable, Iterator

from more_itertools import interleave
from pluggy import PluginManager
Expand Down Expand Up @@ -198,98 +198,178 @@ def _suggest_resume_scenario(

postfix = ""
if done_nodes:
node_names = (n.name for n in remaining_nodes)
resume_p = pipeline.only_nodes(*node_names)
start_p = resume_p.only_nodes_with_inputs(*resume_p.inputs())

# find the nearest persistent ancestors of the nodes in start_p
start_p_persistent_ancestors = _find_persistent_ancestors(
pipeline, start_p.nodes, catalog
start_node_names = _find_nodes_to_resume_from(
pipeline=pipeline,
unfinished_nodes=remaining_nodes,
catalog=catalog,
)

start_node_names = (n.name for n in start_p_persistent_ancestors)
postfix += f" --from-nodes \"{','.join(start_node_names)}\""
start_nodes_str = ",".join(sorted(start_node_names))
postfix += f' --from-nodes "{start_nodes_str}"'

if not postfix:
self._logger.warning(
"No nodes ran. Repeat the previous command to attempt a new run."
)
else:
self._logger.warning(
"There are %d nodes that have not run.\n"
f"There are {len(remaining_nodes)} nodes that have not run.\n"
"You can resume the pipeline run from the nearest nodes with "
"persisted inputs by adding the following "
"argument to your previous command:\n%s",
len(remaining_nodes),
postfix,
f"argument to your previous command:\n{postfix}"
)


def _find_persistent_ancestors(
pipeline: Pipeline, children: Iterable[Node], catalog: DataCatalog
def _find_nodes_to_resume_from(
pipeline: Pipeline, unfinished_nodes: Collection[Node], catalog: DataCatalog
) -> set[str]:
"""Given a collection of unfinished nodes in a pipeline using
a certain catalog, find the node names to pass to pipeline.from_nodes()
to cover all unfinished nodes, including any additional nodes
that should be re-run if their outputs are not persisted.

Args:
pipeline: the ``Pipeline`` to find starting nodes for.
unfinished_nodes: collection of ``Node``s that have not finished yet
catalog: the ``DataCatalog`` of the run.

Returns:
Set of node names to pass to pipeline.from_nodes() to continue
the run.

"""
nodes_to_be_run = _find_all_required_nodes(pipeline, unfinished_nodes, catalog)

# Find which of the remaining nodes would need to run first (in topo sort)
persistent_ancestors = _find_initial_node_group(pipeline, nodes_to_be_run)

return {n.name for n in persistent_ancestors}


def _find_all_required_nodes(
DimedS marked this conversation as resolved.
Show resolved Hide resolved
pipeline: Pipeline, unfinished_nodes: Iterable[Node], catalog: DataCatalog
) -> set[Node]:
"""Breadth-first search approach to finding the complete set of
persistent ancestors of an iterable of ``Node``s. Persistent
ancestors exclusively have persisted ``Dataset``s as inputs.
``Node``s which need to run to cover all unfinished nodes,
including any additional nodes that should be re-run if their outputs
are not persisted.

Args:
pipeline: the ``Pipeline`` to find ancestors in.
children: the iterable containing ``Node``s to find ancestors of.
pipeline: the ``Pipeline`` to analyze.
unfinished_nodes: the iterable of ``Node``s which have not finished yet.
catalog: the ``DataCatalog`` of the run.

Returns:
A set containing first persistent ancestors of the given
``Node``s.
A set containing all input unfinished ``Node``s and all remaining
``Node``s that need to run in case their outputs are not persisted.

"""
ancestor_nodes_to_run = set()
queue, visited = deque(children), set(children)
nodes_to_run = set(unfinished_nodes)
initial_nodes = _nodes_with_external_inputs(unfinished_nodes)

queue, visited = deque(initial_nodes), set(initial_nodes)
DimedS marked this conversation as resolved.
Show resolved Hide resolved
while queue:
current_node = queue.popleft()
if _has_persistent_inputs(current_node, catalog):
ancestor_nodes_to_run.add(current_node)
continue
for parent in _enumerate_parents(pipeline, current_node):
if parent in visited:
nodes_to_run.add(current_node)
non_persistent_inputs = _enumerate_non_persistent_inputs(current_node, catalog)
# Look for the nodes that produce non-persistent inputs (if those exist)
for node in _enumerate_nodes_with_outputs(pipeline, non_persistent_inputs):
if node in visited:
merelcht marked this conversation as resolved.
Show resolved Hide resolved
continue
visited.add(parent)
queue.append(parent)
return ancestor_nodes_to_run
visited.add(node)
queue.append(node)

# Make sure no downstream tasks are skipped
nodes_to_run = set(pipeline.from_nodes(*(n.name for n in nodes_to_run)).nodes)
DimedS marked this conversation as resolved.
Show resolved Hide resolved

def _enumerate_parents(pipeline: Pipeline, child: Node) -> list[Node]:
"""For a given ``Node``, returns a list containing the direct parents
of that ``Node`` in the given ``Pipeline``.
return nodes_to_run


def _nodes_with_external_inputs(nodes_of_interest: Iterable[Node]) -> set[Node]:
merelcht marked this conversation as resolved.
Show resolved Hide resolved
"""For given ``Node``s , find their subset which depends on
external inputs of the ``Pipeline`` they constitute. External inputs
are pipeline inputs not produced by other ``Node``s in the ``Pipeline``.

Args:
pipeline: the ``Pipeline`` to search for direct parents in.
child: the ``Node`` to find parents of.
nodes_of_interest: the ``Node``s to analyze.

Returns:
A list of all ``Node``s that are direct parents of ``child``.
A set of ``Node``s that depend on external inputs
of nodes of interest.

"""
parent_pipeline = pipeline.only_nodes_with_outputs(*child.inputs)
return parent_pipeline.nodes
p_nodes_of_interest = Pipeline(nodes_of_interest)
p_nodes_with_external_inputs = p_nodes_of_interest.only_nodes_with_inputs(
*p_nodes_of_interest.inputs()
)
return set(p_nodes_with_external_inputs.nodes)


def _has_persistent_inputs(node: Node, catalog: DataCatalog) -> bool:
"""Check if a ``Node`` exclusively has persisted Datasets as inputs.
If at least one input is a ``MemoryDataset``, return False.
def _enumerate_non_persistent_inputs(node: Node, catalog: DataCatalog) -> set[str]:
"""Enumerate non-persistent input datasets of a ``Node``.

Args:
node: the ``Node`` to check the inputs of.
catalog: the ``DataCatalog`` of the run.

Returns:
True if the ``Node`` being checked exclusively has inputs that
are not ``MemoryDataset``, else False.
Set of names of non-persistent inputs of given ``Node``.

"""
# We use _datasets because they pertain parameter name format
catalog_datasets = catalog._datasets
non_persistent_inputs: set[str] = set()
for node_input in node.inputs:
if isinstance(catalog._datasets[node_input], MemoryDataset):
return False
return True
if node_input.startswith("params:"):
continue

if (
node_input not in catalog_datasets
or catalog_datasets[node_input]._EPHEMERAL
):
non_persistent_inputs.add(node_input)

return non_persistent_inputs


def _enumerate_nodes_with_outputs(
pipeline: Pipeline, outputs: Collection[str]
) -> list[Node]:
"""For given outputs, returns a list containing nodes that
generate them in the given ``Pipeline``.

Args:
pipeline: the ``Pipeline`` to search for nodes in.
outputs: the dataset names to find source nodes for.

Returns:
A list of all ``Node``s that are producing ``outputs``.

"""
parent_pipeline = pipeline.only_nodes_with_outputs(*outputs)
return parent_pipeline.nodes


def _find_initial_node_group(pipeline: Pipeline, nodes: Iterable[Node]) -> list[Node]:
"""Given a collection of ``Node``s in a ``Pipeline``,
find the initial group of ``Node``s to be run (in topological order).

This can be used to define a sub-pipeline with the smallest possible
set of nodes to pass to --from-nodes.

Args:
pipeline: the ``Pipeline`` to search for initial ``Node``s in.
nodes: the ``Node``s to find initial group for.

Returns:
A list of initial ``Node``s to run given inputs (in topological order).

"""
node_names = set(n.name for n in nodes)
if len(node_names) == 0:
return []
sub_pipeline = pipeline.only_nodes(*node_names)
initial_nodes = sub_pipeline.grouped_nodes[0]
return initial_nodes


def run_node(
Expand Down
111 changes: 103 additions & 8 deletions tests/runner/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,10 @@ def identity(arg):
return arg


def first_arg(*args):
return args[0]


def sink(arg):
pass

Expand All @@ -36,7 +40,7 @@ def return_not_serialisable(arg):
return lambda x: x


def multi_input_list_output(arg1, arg2):
def multi_input_list_output(arg1, arg2, arg3=None):
return [arg1, arg2]


Expand Down Expand Up @@ -80,6 +84,9 @@ def _save(arg):
"ds0_B": persistent_dataset,
"ds2_A": persistent_dataset,
"ds2_B": persistent_dataset,
"dsX": persistent_dataset,
"dsY": persistent_dataset,
"params:p": MemoryDataset(1),
}
)

Expand Down Expand Up @@ -148,21 +155,31 @@ def unfinished_outputs_pipeline():

@pytest.fixture
def two_branches_crossed_pipeline():
"""A ``Pipeline`` with an X-shape (two branches with one common node)"""
r"""A ``Pipeline`` with an X-shape (two branches with one common node):

(node1_A) (node1_B)
\ /
(node2)
/ \
(node3_A) (node3_B)
/ \
(node4_A) (node4_B)

"""
return pipeline(
[
node(identity, "ds0_A", "ds1_A", name="node1_A"),
node(identity, "ds0_B", "ds1_B", name="node1_B"),
node(first_arg, "ds0_A", "ds1_A", name="node1_A"),
node(first_arg, "ds0_B", "ds1_B", name="node1_B"),
node(
multi_input_list_output,
["ds1_A", "ds1_B"],
["ds2_A", "ds2_B"],
name="node2",
),
node(identity, "ds2_A", "ds3_A", name="node3_A"),
node(identity, "ds2_B", "ds3_B", name="node3_B"),
node(identity, "ds3_A", "ds4_A", name="node4_A"),
node(identity, "ds3_B", "ds4_B", name="node4_B"),
node(first_arg, "ds2_A", "ds3_A", name="node3_A"),
node(first_arg, "ds2_B", "ds3_B", name="node3_B"),
node(first_arg, "ds3_A", "ds4_A", name="node4_A"),
node(first_arg, "ds3_B", "ds4_B", name="node4_B"),
]
)

Expand All @@ -175,3 +192,81 @@ def pipeline_with_memory_datasets():
node(func=identity, inputs="Input2", outputs="MemOutput2", name="node2"),
]
)


@pytest.fixture
def pipeline_asymmetric():
r"""

(node1)
\
(node3) (node2)
\ /
(node4)

"""
return pipeline(
[
node(first_arg, ["ds0_A"], ["_ds1"], name="node1"),
node(first_arg, ["ds0_B"], ["_ds2"], name="node2"),
node(first_arg, ["_ds1"], ["_ds3"], name="node3"),
node(first_arg, ["_ds2", "_ds3"], ["_ds4"], name="node4"),
]
)


@pytest.fixture
def pipeline_triangular():
r"""

(node1)
| \
| (node2)
| /
(node3)

"""
return pipeline(
[
node(first_arg, ["ds0_A"], ["_ds1_A"], name="node1"),
node(first_arg, ["_ds1_A"], ["ds2_A"], name="node2"),
node(first_arg, ["ds2_A", "_ds1_A"], ["_ds3_A"], name="node3"),
]
)


@pytest.fixture
def empty_pipeline():
return pipeline([])


@pytest.fixture(
params=[(), ("dsX",), ("params:p",)],
ids=[
"no_extras",
"extra_persistent_ds",
"extra_param",
],
)
def two_branches_crossed_pipeline_variable_inputs(request):
"""A ``Pipeline`` with an X-shape (two branches with one common node).
Non-persistent datasets (other than parameters) are prefixed with an underscore.
"""
extra_inputs = list(request.param)

return pipeline(
[
node(first_arg, ["ds0_A"] + extra_inputs, "_ds1_A", name="node1_A"),
node(first_arg, ["ds0_B"] + extra_inputs, "_ds1_B", name="node1_B"),
node(
multi_input_list_output,
["_ds1_A", "_ds1_B"] + extra_inputs,
["ds2_A", "ds2_B"],
name="node2",
),
node(first_arg, ["ds2_A"] + extra_inputs, "_ds3_A", name="node3_A"),
node(first_arg, ["ds2_B"] + extra_inputs, "_ds3_B", name="node3_B"),
node(first_arg, ["_ds3_A"] + extra_inputs, "_ds4_A", name="node4_A"),
node(first_arg, ["_ds3_B"] + extra_inputs, "_ds4_B", name="node4_B"),
]
)
Loading