diff --git a/RELEASE.md b/RELEASE.md index fe103256cf..7cb9bcc4a8 100644 --- a/RELEASE.md +++ b/RELEASE.md @@ -3,6 +3,7 @@ ## Major features and improvements * Create the debugging line magic `%load_node` for Jupyter Notebook and Jupyter Lab. * Add better IPython, VSCode Notebook support for `%load_node` and minimal support for Databricks. +* Add full Kedro Node input syntax for `%load_node`. ## Bug fixes and other changes * Updated CLI Command `kedro catalog resolve` to work with dataset factories that use `PartitionedDataset`. diff --git a/docs/source/conf.py b/docs/source/conf.py index b8f77dc0f4..983bcd7d2d 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -106,7 +106,6 @@ "kedro_docs_style_guide.md", ] - type_targets = { "py:class": ( "object", diff --git a/kedro/ipython/__init__.py b/kedro/ipython/__init__.py index 2a6b2d66d3..d07a30c479 100644 --- a/kedro/ipython/__init__.py +++ b/kedro/ipython/__init__.py @@ -12,7 +12,8 @@ import typing import warnings from pathlib import Path -from typing import Any, Callable +from types import MappingProxyType +from typing import Any, Callable, OrderedDict from IPython.core.getipython import get_ipython from IPython.core.magic import needs_local_scope, register_line_magic @@ -36,6 +37,8 @@ logger = logging.getLogger(__name__) +FunctionParameters = MappingProxyType + def load_ipython_extension(ipython: Any) -> None: """ @@ -45,9 +48,9 @@ def load_ipython_extension(ipython: Any) -> None: See https://ipython.readthedocs.io/en/stable/config/extensions/index.html """ ipython.register_magic_function(magic_reload_kedro, magic_name="reload_kedro") - logger.info("Registered line magic 'reload_kedro'") + logger.info("Registered line magic '%reload_kedro'") ipython.register_magic_function(magic_load_node, magic_name="load_node") - logger.info("Registered line magic 'load_node'") + logger.info("Registered line magic '%load_node'") if _find_kedro_project(Path.cwd()) is None: logger.warning( @@ -225,7 +228,9 @@ def magic_load_node(args: str) -> None: """ parameters = parse_argstring(magic_load_node, args) - cells = _load_node(parameters.node, pipelines) + node_name = parameters.node + + cells = _load_node(node_name, pipelines) run_environment = _guess_run_environment() if run_environment == "jupyter": @@ -240,6 +245,36 @@ def magic_load_node(args: str) -> None: _print_cells(cells) +class _NodeBoundArguments(inspect.BoundArguments): + """Similar to inspect.BoundArguments""" + + def __init__( + self, signature: inspect.Signature, arguments: OrderedDict[str, Any] + ) -> None: + super().__init__(signature, arguments) + + @property + def input_params_dict(self) -> dict[str, str] | None: + """A mapping of {variable name: dataset_name}""" + var_positional_arg_name = self._find_var_positional_arg() + inputs_params_dict = {} + for param, dataset_name in self.arguments.items(): + if param == var_positional_arg_name: + # If the argument is *args, use the dataset name instead + for arg in dataset_name: + inputs_params_dict[arg] = arg + else: + inputs_params_dict[param] = dataset_name + return inputs_params_dict + + def _find_var_positional_arg(self) -> str | None: + """Find the name of the VAR_POSITIONAL argument( *args), if any.""" + for k, v in self.signature.parameters.items(): + if v.kind == inspect.Parameter.VAR_POSITIONAL: + return k + return None + + def _create_cell_with_text(text: str, is_jupyter: bool = True) -> None: if is_jupyter: from ipylab import JupyterFrontEnd @@ -277,16 +312,20 @@ def _load_node(node_name: str, pipelines: _ProjectPipelines) -> list[str]: node = _find_node(node_name, pipelines) node_func = node.func - node_inputs = _prepare_node_inputs(node) - imports = _prepare_imports(node_func) - function_definition = _prepare_function_body(node_func) - function_call = _prepare_function_call(node_func) + imports_cell = _prepare_imports(node_func) + function_definition_cell = _prepare_function_body(node_func) + + node_bound_arguments = _get_node_bound_arguments(node) + inputs_params_mapping = _prepare_node_inputs(node_bound_arguments) + node_inputs_cell = _format_node_inputs_text(inputs_params_mapping) + function_call_cell = _prepare_function_call(node_func, node_bound_arguments) cells: list[str] = [] - cells.append(node_inputs) - cells.append(imports) - cells.append(function_definition) - cells.append(function_call) + if node_inputs_cell: + cells.append(node_inputs_cell) + cells.append(imports_cell) + cells.append(function_definition_cell) + cells.append(function_call_cell) return cells @@ -323,20 +362,37 @@ def _prepare_imports(node_func: Callable) -> str: raise FileNotFoundError(f"Could not find {node_func.__name__}") -def _prepare_node_inputs(node: Node) -> str: +def _get_node_bound_arguments(node: Node) -> _NodeBoundArguments: node_func = node.func + node_inputs = node.inputs + + args, kwargs = Node._process_inputs_for_bind(node_inputs) signature = inspect.signature(node_func) + bound_arguments = signature.bind(*args, **kwargs) + return _NodeBoundArguments(bound_arguments.signature, bound_arguments.arguments) - node_inputs = node.inputs - func_params = list(signature.parameters) +def _prepare_node_inputs( + node_bound_arguments: _NodeBoundArguments, +) -> dict[str, str] | None: + # Remove the *args. For example {'first_arg':'a', 'args': ('b','c')} + # will be loaded as follow: + # first_arg = catalog.load("a") + # b = catalog.load("b") # It doesn't have an arg name, so use the dataset name instead. + # c = catalog.load("c") + return node_bound_arguments.input_params_dict + + +def _format_node_inputs_text(input_params_dict: dict[str, str] | None) -> str | None: statements = [ "# Prepare necessary inputs for debugging", "# All debugging inputs must be defined in your project catalog", ] + if not input_params_dict: + return None - for node_input, func_param in zip(node_inputs, func_params): - statements.append(f'{func_param} = catalog.load("{node_input}")') + for func_param, dataset_name in input_params_dict.items(): + statements.append(f'{func_param} = catalog.load("{dataset_name}")') input_statements = "\n".join(statements) return input_statements @@ -348,13 +404,19 @@ def _prepare_function_body(func: Callable) -> str: return body -def _prepare_function_call(node_func: Callable) -> str: +def _prepare_function_call( + node_func: Callable, node_bound_arguments: _NodeBoundArguments +) -> str: """Prepare the text for the function call.""" func_name = node_func.__name__ - signature = inspect.signature(node_func) - func_params = list(signature.parameters) + args = node_bound_arguments.input_params_dict + kwargs = node_bound_arguments.kwargs # Construct the statement of func_name(a=1,b=2,c=3) - func_args = ", ".join(func_params) - body = f"""{func_name}({func_args})""" + args_str_literal = [f"{node_input}" for node_input in args] if args else [] + kwargs_str_literal = [ + f"{node_input}={dataset_name}" for node_input, dataset_name in kwargs.items() + ] + func_params = ", ".join(args_str_literal + kwargs_str_literal) + body = f"""{func_name}({func_params})""" return body diff --git a/tests/ipython/conftest.py b/tests/ipython/conftest.py index 2f71021ac2..edc7dbc9e8 100644 --- a/tests/ipython/conftest.py +++ b/tests/ipython/conftest.py @@ -14,6 +14,7 @@ from .dummy_function_fixtures import ( dummy_function, dummy_function_with_loop, + dummy_function_with_variable_length, dummy_nested_function, ) @@ -105,6 +106,36 @@ def dummy_node(): ) +@pytest.fixture +def dummy_node_empty_input(): + return node( + func=dummy_function, + inputs=["", ""], + outputs=[None], + name="dummy_node_empty_input", + ) + + +@pytest.fixture +def dummy_node_dict_input(): + return node( + func=dummy_function, + inputs=dict(dummy_input="dummy_input", my_input="extra_input"), + outputs=["dummy_output"], + name="dummy_node_empty_input", + ) + + +@pytest.fixture +def dummy_node_with_variable_length(): + return node( + func=dummy_function_with_variable_length, + inputs=["dummy_input", "extra_input", "first", "second"], + outputs=["dummy_output"], + name="dummy_node_with_variable_length", + ) + + @pytest.fixture def lambda_node(): return node( diff --git a/tests/ipython/dummy_function_fixtures.py b/tests/ipython/dummy_function_fixtures.py index fd9438460b..2b48957a93 100644 --- a/tests/ipython/dummy_function_fixtures.py +++ b/tests/ipython/dummy_function_fixtures.py @@ -37,3 +37,7 @@ def dummy_function_with_loop(dummy_list): for x in dummy_list: continue return len(dummy_list) + + +def dummy_function_with_variable_length(dummy_input, my_input, *args, **kwargs): + pass diff --git a/tests/ipython/test_ipython.py b/tests/ipython/test_ipython.py index 309b05aa55..333dc5e80e 100644 --- a/tests/ipython/test_ipython.py +++ b/tests/ipython/test_ipython.py @@ -7,6 +7,8 @@ from kedro.framework.project import pipelines from kedro.ipython import ( _find_node, + _format_node_inputs_text, + _get_node_bound_arguments, _load_node, _prepare_function_body, _prepare_imports, @@ -199,13 +201,13 @@ def test_load_extension_register_line_magic(self, mocker, ipython): "--conf-source=new_conf", ], ) - def test_reload_kedro_magic_with_valid_arguments(self, mocker, args, ipython): + def test_line_magic_with_valid_arguments(self, mocker, args, ipython): mocker.patch("kedro.ipython._find_kedro_project") mocker.patch("kedro.ipython.reload_kedro") ipython.magic(f"reload_kedro {args}") - def test_reload_kedro_with_invalid_arguments(self, mocker, ipython): + def test_line_magic_with_invalid_arguments(self, mocker, ipython): mocker.patch("kedro.ipython._find_kedro_project") mocker.patch("kedro.ipython.reload_kedro") load_ipython_extension(ipython) @@ -357,13 +359,48 @@ def test_prepare_node_inputs( self, dummy_node, ): - func_inputs = """# Prepare necessary inputs for debugging -# All debugging inputs must be defined in your project catalog -dummy_input = catalog.load("dummy_input") -my_input = catalog.load("extra_input")""" + expected = {"dummy_input": "dummy_input", "my_input": "extra_input"} + + node_bound_arguments = _get_node_bound_arguments(dummy_node) + result = _prepare_node_inputs(node_bound_arguments) + assert result == expected + + def test_prepare_node_inputs_when_input_is_empty( + self, + dummy_node_empty_input, + ): + expected = {"dummy_input": "", "my_input": ""} + + node_bound_arguments = _get_node_bound_arguments(dummy_node_empty_input) + result = _prepare_node_inputs(node_bound_arguments) + assert result == expected + + def test_prepare_node_inputs_with_dict_input( + self, + dummy_node_dict_input, + ): + expected = {"dummy_input": "dummy_input", "my_input": "extra_input"} + + node_bound_arguments = _get_node_bound_arguments(dummy_node_dict_input) + result = _prepare_node_inputs(node_bound_arguments) + assert result == expected - result = _prepare_node_inputs(dummy_node) - assert result == func_inputs + def test_prepare_node_inputs_with_variable_length_args( + self, + dummy_node_with_variable_length, + ): + expected = { + "dummy_input": "dummy_input", + "my_input": "extra_input", + "first": "first", + "second": "second", + } + + node_bound_arguments = _get_node_bound_arguments( + dummy_node_with_variable_length + ) + result = _prepare_node_inputs(node_bound_arguments) + assert result == expected def test_prepare_function_body(self, dummy_function_defintion): result = _prepare_function_body(dummy_function) @@ -430,3 +467,41 @@ def test_load_node_with_other(self, mocker, ipython, run_env): load_ipython_extension(ipython) ipython.magic("load_node dummy_node") spy.assert_called_once() + + +class TestFormatNodeInputsText: + def test_format_node_inputs_text_empty_input(self): + # Test with empty input_params_dict + input_params_dict = {} + expected_output = None + assert _format_node_inputs_text(input_params_dict) == expected_output + + def test_format_node_inputs_text_single_input(self): + # Test with a single input + input_params_dict = {"input1": "dataset1"} + expected_output = ( + "# Prepare necessary inputs for debugging\n" + "# All debugging inputs must be defined in your project catalog\n" + 'input1 = catalog.load("dataset1")' + ) + assert _format_node_inputs_text(input_params_dict) == expected_output + + def test_format_node_inputs_text_multiple_inputs(self): + # Test with multiple inputs + input_params_dict = { + "input1": "dataset1", + "input2": "dataset2", + "input3": "dataset3", + } + expected_output = ( + "# Prepare necessary inputs for debugging\n" + "# All debugging inputs must be defined in your project catalog\n" + 'input1 = catalog.load("dataset1")\n' + 'input2 = catalog.load("dataset2")\n' + 'input3 = catalog.load("dataset3")' + ) + assert _format_node_inputs_text(input_params_dict) == expected_output + + def test_format_node_inputs_text_no_catalog_load(self): + # Test with no catalog.load() statements if input_params_dict is None + assert _format_node_inputs_text(None) is None