Enable the usage of generator functions in nodes (#2161)
* Modify the node and run_node to enable generator nodes

Signed-off-by: Ivan Danov <idanov@users.noreply.github.com>

* Add tests to cover all types of generator functions

Signed-off-by: Ivan Danov <idanov@users.noreply.github.com>

* Fail on running a generator node with async load/save

Signed-off-by: Ivan Danov <idanov@users.noreply.github.com>

* Lint my code changes

Signed-off-by: Ivan Danov <idanov@users.noreply.github.com>

* Add changelog to RELEASE.md

Signed-off-by: Ivan Danov <idanov@users.noreply.github.com>

* Simplify the usage of spy and clarify with a comment

Signed-off-by: Ivan Danov <idanov@users.noreply.github.com>

* Improve error messaging

* Improve readability slightly in certain places

* Correct the expected error message in node run tests

Signed-off-by: Ivan Danov <idanov@users.noreply.github.com>

* Revert the eager evaluation of the result in _from_dict

Generators cannot refer to themselves in their definition, or they will fail when used.

Signed-off-by: Ivan Danov <idanov@users.noreply.github.com>

* Add a comment for the eager evaluation

Signed-off-by: Ivan Danov <idanov@users.noreply.github.com>

Signed-off-by: Ivan Danov <idanov@users.noreply.github.com>
idanov committed Jan 6, 2023
1 parent 93255a3 commit fcf3ab4
Showing 7 changed files with 166 additions and 20 deletions.
3 changes: 3 additions & 0 deletions Makefile
@@ -19,6 +19,9 @@ test:
 test-no-spark:
 	pytest tests --no-cov --ignore tests/extras/datasets/spark --numprocesses 4 --dist loadfile
 
+test-no-datasets:
+	pytest tests --no-cov --ignore tests/extras/datasets/ --numprocesses 4 --dist loadfile
+
 e2e-tests:
 	behave
 
3 changes: 3 additions & 0 deletions RELEASE.md
@@ -14,6 +14,9 @@
 * Added new `OmegaConfLoader` which uses `OmegaConf` for loading and merging configuration.
 * Added the `--conf-source` option to `kedro run`, allowing users to specify a source for project configuration for the run.
 * Added `omegaconf` syntax as option for `--params`. Keys and values can now be separated by colons or equals signs.
+* Added support for generator functions as nodes, i.e. using `yield` instead of `return`.
+* Enabled chunk-wise processing in nodes with generator functions.
+* Node outputs are now saved after every `yield`, before processing the next chunk.
 
 ## Bug fixes and other changes
 * Fix bug where `micropkg` manifest section in `pyproject.toml` isn't recognised as allowed configuration.
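To make the new behaviour concrete, here is a minimal sketch of a generator node (the function, node and dataset names are hypothetical, chosen to mirror the tests added below):

from kedro.pipeline import node

def generate_squares():
    # Yield (index, square) chunk by chunk instead of returning one result.
    for i in range(10):
        yield i, i * i

# Each yielded tuple is split across the two outputs, and the runner saves
# every chunk to its dataset before the next chunk is generated, so the
# full result never has to be held in memory at once.
squares_node = node(generate_squares, inputs=None, outputs=["index", "square"])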
1 change: 1 addition & 0 deletions dependency/requirements.txt
@@ -10,6 +10,7 @@ importlib-metadata>=3.6; python_version >= '3.8'
 importlib_metadata>=3.6, <5.0; python_version < '3.8' # The "selectable" entry points were introduced in `importlib_metadata` 3.6 and Python 3.10. Bandit on Python 3.7 relies on a library with `importlib_metadata` < 5.0
 importlib_resources>=1.3 # The `files()` API was introduced in `importlib_resources` 1.3 and Python 3.9.
 jmespath>=0.9.5, <1.0
+more_itertools~=9.0
 omegaconf~=2.3
 pip-tools~=6.12
 pluggy~=1.0.0
53 changes: 37 additions & 16 deletions kedro/pipeline/node.py
@@ -9,6 +9,8 @@
 from typing import Any, Callable, Dict, Iterable, List, Optional, Set, Union
 from warnings import warn
 
+from more_itertools import spy, unzip
+
 
 class Node:
     """``Node`` is an auxiliary class facilitating the operations required to
@@ -397,38 +399,57 @@ def _run_with_dict(self, inputs: Dict[str, Any], node_inputs: Dict[str, str]):
 
     def _outputs_to_dictionary(self, outputs):
         def _from_dict():
-            if set(self._outputs.keys()) != set(outputs.keys()):
+            result, iterator = outputs, None
+            # generator functions are lazy and we need a peek into their first output
+            if inspect.isgenerator(outputs):
+                (result,), iterator = spy(outputs)
+
+            keys = list(self._outputs.keys())
+            names = list(self._outputs.values())
+            if not isinstance(result, dict):
+                raise ValueError(
+                    f"Failed to save outputs of node {self}.\n"
+                    f"The node output is a dictionary, whereas the "
+                    f"function output is {type(result)}."
+                )
+            if set(keys) != set(result.keys()):
                 raise ValueError(
                     f"Failed to save outputs of node {str(self)}.\n"
-                    f"The node's output keys {set(outputs.keys())} do not match with "
-                    f"the returned output's keys {set(self._outputs.keys())}."
+                    f"The node's output keys {set(result.keys())} "
+                    f"do not match with the returned output's keys {set(keys)}."
                 )
-            return {name: outputs[key] for key, name in self._outputs.items()}
+            if iterator:
+                exploded = map(lambda x: tuple(x[k] for k in keys), iterator)
+                result = unzip(exploded)
+            else:
+                # evaluate this eagerly so we can reuse variable name
+                result = tuple(result[k] for k in keys)
+            return dict(zip(names, result))
 
         def _from_list():
-            if not isinstance(outputs, (list, tuple)):
+            result, iterator = outputs, None
+            # generator functions are lazy and we need a peek into their first output
+            if inspect.isgenerator(outputs):
+                (result,), iterator = spy(outputs)
+
+            if not isinstance(result, (list, tuple)):
                 raise ValueError(
                     f"Failed to save outputs of node {str(self)}.\n"
                     f"The node definition contains a list of "
                     f"outputs {self._outputs}, whereas the node function "
-                    f"returned a '{type(outputs).__name__}'."
+                    f"returned a '{type(result).__name__}'."
                 )
-            if len(outputs) != len(self._outputs):
+            if len(result) != len(self._outputs):
                 raise ValueError(
                     f"Failed to save outputs of node {str(self)}.\n"
-                    f"The node function returned {len(outputs)} output(s), "
+                    f"The node function returned {len(result)} output(s), "
                     f"whereas the node definition contains {len(self._outputs)} "
                     f"output(s)."
                 )
-
-            return dict(zip(self._outputs, outputs))
-
-        if isinstance(self._outputs, dict) and not isinstance(outputs, dict):
-            raise ValueError(
-                f"Failed to save outputs of node {self}.\n"
-                f"The node output is a dictionary, whereas the "
-                f"function output is not."
-            )
+            if iterator:
+                result = unzip(iterator)
+            return dict(zip(self._outputs, result))
 
         if self._outputs is None:
             return {}
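Since `spy` and `unzip` do the heavy lifting above, here is a short self-contained sketch of their semantics (illustrative only, not part of the commit):

from more_itertools import spy, unzip

def chunks():
    yield {"a": 1, "b": 2}
    yield {"a": 3, "b": 4}

# spy returns the first item(s) as a list plus an iterator equivalent to the
# original, which is why the code unpacks `(result,), iterator = spy(outputs)`.
(first,), rest = spy(chunks())
assert first == {"a": 1, "b": 2}

# unzip is the inverse of zip: an iterable of tuples becomes one lazy
# iterator per position, i.e. one stream of chunks per output key.
a_stream, b_stream = unzip((d["a"], d["b"]) for d in rest)
assert list(a_stream) == [1, 3]
assert list(b_stream) == [2, 4]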
33 changes: 31 additions & 2 deletions kedro/runner/runner.py
@@ -2,6 +2,8 @@
 implementations.
 """
 
+import inspect
+import itertools as it
 import logging
 from abc import ABC, abstractmethod
 from collections import deque
@@ -12,8 +14,9 @@
     as_completed,
     wait,
 )
-from typing import Any, Dict, Iterable, List, Set
+from typing import Any, Dict, Iterable, Iterator, List, Set
 
+from more_itertools import interleave
 from pluggy import PluginManager
 
 from kedro.framework.hooks.manager import _NullPluginManager
@@ -294,10 +297,22 @@ def run_node(
             asynchronously with threads. Defaults to False.
         session_id: The session id of the pipeline run.
 
+    Raises:
+        ValueError: Raised if is_async is set to True for nodes wrapping
+            generator functions.
+
     Returns:
         The node argument.
 
     """
+    if is_async and inspect.isgeneratorfunction(node.func):
+        raise ValueError(
+            f"Async data loading and saving does not work with "
+            f"nodes wrapping generator functions. Please make "
+            f"sure you don't use `yield` anywhere "
+            f"in node {str(node)}."
+        )
+
     if is_async:
         node = _run_node_async(node, catalog, hook_manager, session_id)
     else:
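Note that this guard inspects the wrapped function itself, while node.py above inspects the object the function returns once called; a quick illustration of the two `inspect` checks (not from the commit):

import inspect

def chunked():
    yield 1

assert inspect.isgeneratorfunction(chunked)  # true before the node ever runs
assert inspect.isgenerator(chunked())        # true only for the returned generator
assert not inspect.isgenerator(chunked)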
@@ -399,7 +414,21 @@ def _run_node_sequential(
         node, catalog, inputs, is_async, hook_manager, session_id=session_id
     )
 
-    for name, data in outputs.items():
+    items: Iterable = outputs.items()
+    # if all outputs are iterators, then the node is a generator node
+    if all(isinstance(d, Iterator) for d in outputs.values()):
+        # Python dictionaries are ordered so we are sure
+        # the keys and the chunk streams are in the same order
+        # [a, b, c]
+        keys = list(outputs.keys())
+        # [Iterator[chunk_a], Iterator[chunk_b], Iterator[chunk_c]]
+        streams = list(outputs.values())
+        # zip an endless cycle of the keys
+        # with an interleaved iterator of the streams
+        # [(a, chunk_a), (b, chunk_b), ...] until all outputs complete
+        items = zip(it.cycle(keys), interleave(*streams))
+
+    for name, data in items:
         hook_manager.hook.before_dataset_saved(dataset_name=name, data=data)
         catalog.save(name, data)
         hook_manager.hook.after_dataset_saved(dataset_name=name, data=data)
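To see how cycling the keys lines up with the interleaved streams, here is a small self-contained illustration with made-up chunk values:

import itertools as it
from more_itertools import interleave

keys = ["left", "right"]
left_stream = iter([0, 1, 2])
right_stream = iter([0, 1, 4])

# interleave yields left[0], right[0], left[1], right[1], ... until the
# shortest stream is exhausted; pairing it with an endless cycle of the
# keys tags each chunk with the dataset it should be saved to.
pairs = list(zip(it.cycle(keys), interleave(left_stream, right_stream)))
assert pairs == [
    ("left", 0), ("right", 0),
    ("left", 1), ("right", 1),
    ("left", 2), ("right", 4),
]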
4 changes: 2 additions & 2 deletions tests/pipeline/test_node_run.py
@@ -109,8 +109,8 @@ def test_run_dict_diff_size(self, mocked_dataset):
 
 class TestNodeRunInvalidOutput:
     def test_miss_matching_output_types(self, mocked_dataset):
-        pattern = r"The node output is a dictionary, whereas the function "
-        pattern += r"output is not\."
+        pattern = "The node output is a dictionary, whereas the function "
+        pattern += "output is <class 'kedro.io.lambda_dataset.LambdaDataSet'>."
         with pytest.raises(ValueError, match=pattern):
             node(one_in_one_out, "ds1", dict(a="ds")).run(dict(ds1=mocked_dataset))
89 changes: 89 additions & 0 deletions tests/runner/test_run_node.py
@@ -0,0 +1,89 @@
import pytest

from kedro.framework.hooks.manager import _NullPluginManager
from kedro.pipeline import node
from kedro.runner import run_node


def generate_one():
yield from range(10)


def generate_tuple():
for i in range(10):
yield i, i * i


def generate_list():
for i in range(10):
yield [i, i * i]


def generate_dict():
for i in range(10):
yield dict(idx=i, square=i * i)


class TestRunGeneratorNode:
def test_generator_fail_async(self, mocker, catalog):
fake_dataset = mocker.Mock()
catalog.add("result", fake_dataset)
n = node(generate_one, inputs=None, outputs="result")

with pytest.raises(Exception, match="nodes wrapping generator functions"):
run_node(n, catalog, _NullPluginManager(), is_async=True)

def test_generator_node_one(self, mocker, catalog):
fake_dataset = mocker.Mock()
catalog.add("result", fake_dataset)
n = node(generate_one, inputs=None, outputs="result")
run_node(n, catalog, _NullPluginManager())

expected = [((i,),) for i in range(10)]
assert 10 == fake_dataset.save.call_count
assert fake_dataset.save.call_args_list == expected

def test_generator_node_tuple(self, mocker, catalog):
left = mocker.Mock()
right = mocker.Mock()
catalog.add("left", left)
catalog.add("right", right)
n = node(generate_tuple, inputs=None, outputs=["left", "right"])
run_node(n, catalog, _NullPluginManager())

expected_left = [((i,),) for i in range(10)]
expected_right = [((i * i,),) for i in range(10)]
assert 10 == left.save.call_count
assert left.save.call_args_list == expected_left
assert 10 == right.save.call_count
assert right.save.call_args_list == expected_right

def test_generator_node_list(self, mocker, catalog):
left = mocker.Mock()
right = mocker.Mock()
catalog.add("left", left)
catalog.add("right", right)
n = node(generate_list, inputs=None, outputs=["left", "right"])
run_node(n, catalog, _NullPluginManager())

expected_left = [((i,),) for i in range(10)]
expected_right = [((i * i,),) for i in range(10)]
assert 10 == left.save.call_count
assert left.save.call_args_list == expected_left
assert 10 == right.save.call_count
assert right.save.call_args_list == expected_right

def test_generator_node_dict(self, mocker, catalog):
left = mocker.Mock()
right = mocker.Mock()
catalog.add("left", left)
catalog.add("right", right)
n = node(generate_dict, inputs=None, outputs=dict(idx="left", square="right"))
run_node(n, catalog, _NullPluginManager())

expected_left = [((i,),) for i in range(10)]
expected_right = [((i * i,),) for i in range(10)]
assert 10 == left.save.call_count
assert left.save.call_args_list == expected_left
assert 10 == right.save.call_count
assert right.save.call_args_list == expected_right
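The `((i,),)` shape in the expected lists is how `unittest.mock` records a positional-only call: a call object compares equal to an `(args, kwargs)` tuple, and empty kwargs may be omitted. A brief standalone sketch (plain `unittest.mock`, no Kedro involved):

from unittest.mock import Mock

ds = Mock()
for i in range(3):
    ds.save(i)

# call(i) == ((i,),) because the kwargs dict is empty and can be left out,
# which is exactly the shape the list comprehensions above build.
assert ds.save.call_count == 3
assert ds.save.call_args_list == [((i,),) for i in range(3)]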
