Skip to content

Commit

Permalink
%load_node experiments (#3568)
Browse files Browse the repository at this point in the history
* Simplify mocking

Signed-off-by: Ahdra Merali <ahdra.merali@quantumblack.com>

* Check node func names

Signed-off-by: Ahdra Merali <ahdra.merali@quantumblack.com>

* Naive fix for return statements

Signed-off-by: Ahdra Merali <ahdra.merali@quantumblack.com>

* Handle nested case

Signed-off-by: Ahdra Merali <ahdra.merali@quantumblack.com>

* Change pipelines fixture type

Signed-off-by: Ahdra Merali <ahdra.merali@quantumblack.com>

* Remove unnecessary TODO

Signed-off-by: Ahdra Merali <ahdra.merali@quantumblack.com>

* Revert "Check node func names"

This reverts commit 63ee194.

Signed-off-by: Ahdra Merali <ahdra.merali@quantumblack.com>

* Replace commented return statements with a display() statement

Signed-off-by: Ahdra Merali <ahdra.merali@quantumblack.com>

* Add warning about node name when node not found

Signed-off-by: Ahdra Merali <ahdra.merali@quantumblack.com>

* Add line about debugging inputs in catalog

Signed-off-by: Ahdra Merali <ahdra.merali@quantumblack.com>

* Lint

Signed-off-by: Ahdra Merali <ahdra.merali@quantumblack.com>

* Change wording

Signed-off-by: Ahdra Merali <ahdra.merali@quantumblack.com>

* Revert "Replace commented return statements with a display() statement"

This reverts commit ad63afc.

Signed-off-by: Ahdra Merali <ahdra.merali@quantumblack.com>

* Revert "Naive fix for return statements"

This reverts commit 04c022e.

Signed-off-by: Ahdra Merali <ahdra.merali@quantumblack.com>

* Update tests

Signed-off-by: Ahdra Merali <ahdra.merali@quantumblack.com>

---------

Signed-off-by: Ahdra Merali <ahdra.merali@quantumblack.com>
  • Loading branch information
AhdraMeraliQB authored Feb 1, 2024
1 parent 9d12c8d commit 5d3b898
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 47 deletions.
9 changes: 7 additions & 2 deletions kedro/ipython/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -250,7 +250,9 @@ def _find_node(node_name: str, pipelines: _ProjectPipelines) -> Node:
except ValueError:
continue
# If reached the node was not found in the project
raise ValueError(f"Node with name='{node_name}' not found in any pipelines.")
raise ValueError(
f"Node with name='{node_name}' not found in any pipelines. Remember to specify the node name, not the node function."
)


def _prepare_imports(node_func: Callable) -> str:
Expand Down Expand Up @@ -280,7 +282,10 @@ def _prepare_node_inputs(node: Node) -> str:
node_inputs = node.inputs
func_params = list(signature.parameters)

statements = ["# Prepare necessary inputs for debugging"]
statements = [
"# Prepare necessary inputs for debugging",
"# All debugging inputs must be defined in your project catalog",
]

for node_input, func_param in zip(node_inputs, func_params):
statements.append(f'{func_param} = catalog.load("{node_input}")')
Expand Down
69 changes: 24 additions & 45 deletions tests/ipython/test_ipython.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,9 +128,9 @@ def my_func:,


@pytest.fixture
def dummy_pipeline(dummy_node):
# return a list of pipelines
return {"dummy": modular_pipeline([dummy_node])}
def dummy_pipelines(dummy_node):
# return a dict of pipelines
return {"dummy_pipeline": modular_pipeline([dummy_node])}


class TestLoadKedroObjects:
Expand Down Expand Up @@ -382,7 +382,7 @@ class MockKedroContext:


class TestLoadNodeMagic:
def test_load_node_magic(self, mocker, dummy_function_file_lines, dummy_pipeline):
def test_load_node_magic(self, mocker, dummy_function_file_lines, dummy_pipelines):
# Reimport `pipelines` from `kedro.framework.project` to ensure that
# it was not removed by prior tests.
from kedro.framework.project import pipelines
Expand All @@ -393,40 +393,22 @@ def test_load_node_magic(self, mocker, dummy_function_file_lines, dummy_pipeline
mocker.patch(
"builtins.open", mocker.mock_open(read_data=dummy_function_file_lines)
)
pipelines.configure("dummy_pipeline") # Setup the pipelines
my_pipelines = dummy_pipeline

def my_register_pipeline():
return my_pipelines

mocker.patch.object(
pipelines,
"_get_pipelines_registry_callable",
return_value=my_register_pipeline,
)
mock_pipeline_values = dummy_pipelines.values()
mocker.patch.object(pipelines, "values", return_value=mock_pipeline_values)

node_to_load = "dummy_node"
magic_load_node(node_to_load)

def test_load_node(self, mocker, dummy_function_file_lines, dummy_pipeline):
def test_load_node(self, mocker, dummy_function_file_lines, dummy_pipelines):
# wraps all the other functions
mocker.patch(
"builtins.open", mocker.mock_open(read_data=dummy_function_file_lines)
)
pipelines.configure("dummy_pipeline") # Setup the pipelines

my_pipelines = dummy_pipeline

def my_register_pipeline():
return my_pipelines

mocker.patch.object(
pipelines,
"_get_pipelines_registry_callable",
return_value=my_register_pipeline,
)
mock_pipeline_values = dummy_pipelines.values()
mocker.patch.object(pipelines, "values", return_value=mock_pipeline_values)

node_inputs = """# Prepare necessary inputs for debugging
# All debugging inputs must be defined in your project catalog
dummy_input = catalog.load("dummy_input")
my_input = catalog.load("extra_input")"""

Expand Down Expand Up @@ -455,20 +437,20 @@ def my_register_pipeline():
for cell, expected_cell in zip(cells_list, expected_cells):
assert cell == expected_cell

def test_find_node(self, mocker, dummy_pipeline, dummy_node):
mocker.patch.object(pipelines, "values", return_value=dummy_pipeline)
def test_find_node(self, dummy_pipelines, dummy_node):
node_to_find = "dummy_node"
result = _find_node(node_to_find, dummy_pipeline)
result = _find_node(node_to_find, dummy_pipelines)
assert result == dummy_node

def test_node_not_found(self, mocker, dummy_pipeline):
mocker.patch.object(pipelines, "values", return_value=dummy_pipeline)
def test_node_not_found(self, dummy_pipelines):
node_to_find = "not_a_node"
dummy_registered_pipelines = dummy_pipelines
with pytest.raises(ValueError) as excinfo:
_find_node(node_to_find, dummy_pipeline)
_find_node(node_to_find, dummy_registered_pipelines)

assert f"Node with name='{node_to_find}' not found in any pipelines." in str(
excinfo.value
assert (
f"Node with name='{node_to_find}' not found in any pipelines. Remember to specify the node name, not the node function."
in str(excinfo.value)
)

def test_prepare_imports(self, mocker, dummy_function_file_lines):
Expand All @@ -493,8 +475,8 @@ def test_prepare_imports_func_not_found(self, mocker):
assert f"Could not find {dummy_function.__name__}" in str(excinfo.value)

def test_prepare_node_inputs(self, dummy_node):
# TODO Ahdra check - does this address parameters properly?
func_inputs = """# Prepare necessary inputs for debugging
# All debugging inputs must be defined in your project catalog
dummy_input = catalog.load("dummy_input")
my_input = catalog.load("extra_input")"""

Expand All @@ -516,16 +498,13 @@ def test_get_lambda_function_body(self, lambda_node):
result = _prepare_function_body(lambda_node.func)
assert result == "func=lambda x: x"

@pytest.mark.skip(reason="Not supported yet")
def test_get_nested_function_body(self):
func_strings = [
"def nested_function(input):",
"\nreturn not input",
"\nreturn nested_function(dummy_input)\n",
]
func_strings = """def nested_function(input):
return not input
return nested_function(dummy_input)"""

result = _prepare_function_body(dummy_nested_function)
assert result == "".join(func_strings)
# TODO fix - fails because skips nested function definition
assert result == func_strings

def test_get_function_with_loop_body(self):
func_strings = """for x in dummy_list:
Expand Down

0 comments on commit 5d3b898

Please sign in to comment.