From 200ebdb2e80dc9b8ce2d523593815558cae82c69 Mon Sep 17 00:00:00 2001 From: Ankita Katiyar <110245118+ankatiyar@users.noreply.github.com> Date: Thu, 2 Mar 2023 17:49:33 +0000 Subject: [PATCH] Fix overwriting of nested parameters in config by runtime parameters (#2378) * Fix updating of nested params from CLI Signed-off-by: Ankita Katiyar * Convert DictConfig to dict for proper merging Signed-off-by: Ankita Katiyar * revert omegaconf Signed-off-by: Ankita Katiyar * revert utils indent Signed-off-by: Ankita Katiyar * Test for nested params with omegaconf Signed-off-by: Ankita Katiyar * Add test for checking store does not contain DictConfig Signed-off-by: Ankita Katiyar * docslinkcheck + move fn outside Signed-off-by: Ankita Katiyar --------- Signed-off-by: Ankita Katiyar --- docs/source/development/commands_reference.md | 2 +- docs/source/development/linting.md | 2 +- kedro/framework/cli/utils.py | 2 +- kedro/framework/context/context.py | 8 +++--- tests/framework/context/test_context.py | 7 ++++- tests/framework/session/test_session.py | 26 +++++++++++++++++++ 6 files changed, 40 insertions(+), 7 deletions(-) diff --git a/docs/source/development/commands_reference.md b/docs/source/development/commands_reference.md index 9c80fc0ed0..64792e4a1e 100644 --- a/docs/source/development/commands_reference.md +++ b/docs/source/development/commands_reference.md @@ -408,7 +408,7 @@ _This command will be deprecated from Kedro version 0.19.0._ kedro lint ``` -Your project is linted with [`black`](https://github.com/psf/black), [`flake8`](https://gitlab.com/pycqa/flake8) and [`isort`](https://github.com/PyCQA/isort). +Your project is linted with [`black`](https://github.com/psf/black), [`flake8`](https://github.com/PyCQA/flake8) and [`isort`](https://github.com/PyCQA/isort). #### Test your project diff --git a/docs/source/development/linting.md b/docs/source/development/linting.md index 545cf679e2..240f3d30c1 100644 --- a/docs/source/development/linting.md +++ b/docs/source/development/linting.md @@ -8,7 +8,7 @@ consistent. ## Set up linting tools There are a variety of linting tools available to use with your Kedro projects. This guide shows you how to use -[`black`](https://github.com/psf/black), [`flake8`](https://gitlab.com/pycqa/flake8), and +[`black`](https://github.com/psf/black), [`flake8`](https://github.com/PyCQA/flake8), and [`isort`](https://github.com/PyCQA/isort) to lint your Kedro projects. - **`black`** is a [PEP 8](https://peps.python.org/pep-0008/) compliant opinionated Python code formatter. `black` can check for styling inconsistencies and reformat your files in place. diff --git a/kedro/framework/cli/utils.py b/kedro/framework/cli/utils.py index 4105a1e9f0..ab24378870 100644 --- a/kedro/framework/cli/utils.py +++ b/kedro/framework/cli/utils.py @@ -469,7 +469,7 @@ def _split_params(ctx, param, value): ) dot_list.append(item) conf = OmegaConf.from_dotlist(dot_list) - return conf + return OmegaConf.to_container(conf) def _split_load_versions(ctx, param, value): diff --git a/kedro/framework/context/context.py b/kedro/framework/context/context.py index c76ed2ed87..75a782b48b 100644 --- a/kedro/framework/context/context.py +++ b/kedro/framework/context/context.py @@ -7,6 +7,7 @@ from urllib.parse import urlparse from warnings import warn +from omegaconf import DictConfig from pluggy import PluginManager from kedro.config import ConfigLoader, MissingConfigException @@ -154,7 +155,9 @@ def _update_nested_dict(old_dict: Dict[Any, Any], new_dict: Dict[Any, Any]) -> N if key not in old_dict: old_dict[key] = value else: - if isinstance(old_dict[key], dict) and isinstance(value, dict): + if isinstance(old_dict[key], (dict, DictConfig)) and isinstance( + value, (dict, DictConfig) + ): _update_nested_dict(old_dict[key], value) else: old_dict[key] = value @@ -322,8 +325,7 @@ def _add_param_to_feed_dict(param_name, param_value): """ key = f"params:{param_name}" feed_dict[key] = param_value - - if isinstance(param_value, dict): + if isinstance(param_value, (dict, DictConfig)): for key, val in param_value.items(): _add_param_to_feed_dict(f"{param_name}.{key}", val) diff --git a/tests/framework/context/test_context.py b/tests/framework/context/test_context.py index 08bf0f2014..b9df3dce22 100644 --- a/tests/framework/context/test_context.py +++ b/tests/framework/context/test_context.py @@ -10,6 +10,7 @@ import pytest import toml import yaml +from omegaconf import OmegaConf from pandas.util.testing import assert_frame_equal from kedro import __version__ as kedro_version @@ -163,7 +164,6 @@ def dummy_dataframe(): ' --from-nodes "nodes3"' ) - expected_message_head = ( "There are 4 nodes that have not run.\n" "You can resume the pipeline run by adding the following " @@ -484,6 +484,11 @@ def test_validate_layers_error(layers, conflicting_datasets, mocker): {"a": {"a.c": {"a.c.b": 4}}}, {"a": {"a.a": 1, "a.b": 2, "a.c": {"a.c.a": 3, "a.c.b": 4}}}, ), + ( + {"a": OmegaConf.create({"b": 1}), "x": 3}, + {"a": {"c": 2}}, + {"a": {"b": 1, "c": 2}, "x": 3}, + ), ], ) def test_update_nested_dict(old_dict: Dict, new_dict: Dict, expected: Dict): diff --git a/tests/framework/session/test_session.py b/tests/framework/session/test_session.py index 1cccce0426..bb014baafa 100644 --- a/tests/framework/session/test_session.py +++ b/tests/framework/session/test_session.py @@ -2,14 +2,17 @@ import re import subprocess import textwrap +from collections.abc import Mapping from pathlib import Path import pytest import toml import yaml +from omegaconf import OmegaConf from kedro import __version__ as kedro_version from kedro.config import AbstractConfigLoader, ConfigLoader, OmegaConfigLoader +from kedro.framework.cli.utils import _split_params from kedro.framework.context import KedroContext from kedro.framework.project import ( ValidationError, @@ -920,3 +923,26 @@ def test_setup_logging_using_omega_config_loader_class( ).as_posix() actual_log_filepath = call_args["handlers"]["info_file_handler"]["filename"] assert actual_log_filepath == expected_log_filepath + + +def get_all_values(mapping: Mapping): + for value in mapping.values(): + yield value + if isinstance(value, Mapping): + yield from get_all_values(value) + + +@pytest.mark.parametrize("params", ["a=1,b.c=2", "a=1,b=2,c=3", ""]) +def test_no_DictConfig_in_store( + params, + mock_package_name, + fake_project, +): + extra_params = _split_params(None, None, params) + session = KedroSession.create( + mock_package_name, fake_project, extra_params=extra_params + ) + + assert not any( + OmegaConf.is_config(value) for value in get_all_values(session._store) + )