From 5484b9b4bed1878c0dc25fc4feae2f2bd731cab4 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Lorena=20B=C4=83lan?=
Date: Tue, 25 Aug 2020 09:26:40 +0100
Subject: [PATCH] [KED-1959] Implement hook integration for other library components: ConfigLoader (#761)
---
 .circleci/config.yml                         |   3 +-
 RELEASE.md                                   |   6 +-
 .../02_configuration.md                      |   7 +-
 docs/source/07_extend_kedro/04_hooks.md      |   1 +
 features/steps/hooks_template.py             |   5 +
 features/windows_reqs.txt                    |  12 +
 kedro/config/templated_config.py             |  23 +-
 kedro/framework/context/context.py           |  10 +-
 kedro/framework/hooks/manager.py             |   3 +-
 kedro/framework/hooks/specs.py               |  38 +-
 .../hooks.py                                 |   7 +-
 setup.cfg                                    |   1 +
 test_requirements.txt                        |   1 +
 tests/framework/context/test_context.py      | 363 ++++++++----------
 tests/framework/context/test_static_data.py  | 153 ++++++++
 tests/framework/hooks/test_context_hooks.py  |  68 +++-
 16 files changed, 438 insertions(+), 263 deletions(-)
 create mode 100644 features/windows_reqs.txt
 create mode 100644 tests/framework/context/test_static_data.py

diff --git a/.circleci/config.yml b/.circleci/config.yml
index 6b446e7668..adf70a9c4f 100644
--- a/.circleci/config.yml
+++ b/.circleci/config.yml
@@ -222,8 +222,7 @@ commands:
           name: Install dependencies
           command: |
             conda activate kedro_builder
-            cat *requirements.txt | Select-String -Pattern behave,psutil,requests[^\-],^pandas[^\-],cachetools,pluggy,toposort,yaml | %{ $_ -Replace "#.*", "" } > e2e.txt
-            pip install -r e2e.txt
+            pip install -r features/windows_reqs.txt
             choco install make
       - run:
           name: Run e2e tests
diff --git a/RELEASE.md b/RELEASE.md
index d0b7f03b23..1ee5ba4a12 100644
--- a/RELEASE.md
+++ b/RELEASE.md
@@ -11,8 +11,12 @@
 # Upcoming Release 0.16.5

 ## Major features and improvements
-* Added `register_pipelines()`, a new hook to register a project's pipelines. The order of execution is: plugin hooks, `.kedro.yml` hooks, hooks in `ProjectContext.hooks`.
 * Added support for `pyproject.toml` to configure Kedro. `pyproject.toml` is used if `.kedro.yml` doesn't exist (Kedro configuration should be under `[tool.kedro]` section).
+* Projects created with this version will no longer have a `pipeline.py` file; it has been replaced by `hooks.py`.
+* Added a set of registration hooks, as the new way of registering library components with a Kedro project:
+  * `register_pipelines()`, to replace `_get_pipelines()`
+  * `register_config_loader(conf_paths)`, to replace `_create_config_loader()`
+These can be defined in `src/<python_package>/hooks.py` and added to `.kedro.yml` (or `pyproject.toml`). The order of execution is: plugin hooks, `.kedro.yml` hooks, hooks in `ProjectContext.hooks`.

 ## Bug fixes and other changes
 * `project_name`, `project_version` and `package_name` now have to be defined in `.kedro.yml` for the projects generated using Kedro 0.16.5+.
diff --git a/docs/source/04_kedro_project_setup/02_configuration.md b/docs/source/04_kedro_project_setup/02_configuration.md
index a96887b110..b83d75c1fa 100644
--- a/docs/source/04_kedro_project_setup/02_configuration.md
+++ b/docs/source/04_kedro_project_setup/02_configuration.md
@@ -71,14 +71,15 @@ export KEDRO_ENV=test

 ## Templating configuration

-Kedro also provides an extension [TemplatedConfigLoader](/kedro.config.TemplatedConfigLoader) class that allows to template values in your configuration files. `TemplatedConfigLoader` is available in `kedro.config`, to apply templating to your `ProjectContext` in `src/<project_name>/run.py`, you will need to overwrite the `_create_config_loader` method as follows:
+Kedro also provides an extension [TemplatedConfigLoader](/kedro.config.TemplatedConfigLoader) class that allows you to template values in your configuration files. `TemplatedConfigLoader` is available in `kedro.config`. To apply templating to your project, update the `register_config_loader` hook implementation in `src/<python_package>/hooks.py`:

 ```python
 from kedro.config import TemplatedConfigLoader  # new import


-class ProjectContext(KedroContext):
-    def _create_config_loader(self, conf_paths: Iterable[str]) -> TemplatedConfigLoader:
+class ProjectHooks:
+    @hook_impl
+    def register_config_loader(self, conf_paths: Iterable[str]) -> ConfigLoader:
         return TemplatedConfigLoader(
             conf_paths,
             globals_pattern="*globals.yml",  # read the globals dictionary from project config
diff --git a/docs/source/07_extend_kedro/04_hooks.md b/docs/source/07_extend_kedro/04_hooks.md
index 16ea598564..6d7f74449b 100644
--- a/docs/source/07_extend_kedro/04_hooks.md
+++ b/docs/source/07_extend_kedro/04_hooks.md
@@ -50,6 +50,7 @@ The naming convention for error hooks is `on_<noun>_error`, in which:

 In addition, Kedro defines Hook specifications to register certain library components to be used with the project. This is where users can define their custom class implementations. Currently, the following Hook specifications are provided:

 * `register_pipelines`
+* `register_config_loader`

 The naming convention for registration hooks is `register_<library_component>`.
diff --git a/features/steps/hooks_template.py b/features/steps/hooks_template.py
index 95a532cd63..30fe18c1d3 100644
--- a/features/steps/hooks_template.py
+++ b/features/steps/hooks_template.py
@@ -28,6 +28,7 @@

 import pandas as pd

+from kedro.config import ConfigLoader
 from kedro.framework.hooks import hook_impl
 from kedro.pipeline import Pipeline, node

@@ -80,5 +81,9 @@ def register_pipelines(self):  # pylint: disable=no-self-use

         return {"__default__": example_pipeline}

+    @hook_impl
+    def register_config_loader(self, conf_paths):  # pylint: disable=no-self-use
+        return ConfigLoader(conf_paths)
+

 project_hooks = ProjectHooks()
diff --git a/features/windows_reqs.txt b/features/windows_reqs.txt
new file mode 100644
index 0000000000..c2905bf63f
--- /dev/null
+++ b/features/windows_reqs.txt
@@ -0,0 +1,12 @@
+# same versions as `test_requirements`
+# e2e tests on Windows are slow but we don't need to install
+# everything, so just this subset will be enough for CI
+behave==1.2.6
+cachetools~=4.1
+jmespath>=0.9.5, <1.0
+pandas>=0.24.0, <1.0.4
+pluggy~=0.13.0
+psutil==5.6.6
+requests~=2.20
+toposort~=1.5
+PyYAML>=4.2, <6.0
diff --git a/kedro/config/templated_config.py b/kedro/config/templated_config.py
index 7297d88fed..bd2b4ba684 100644
--- a/kedro/config/templated_config.py
+++ b/kedro/config/templated_config.py
@@ -55,34 +55,25 @@ class TemplatedConfigLoader(ConfigLoader):
     wrapped in brackets like: ${...}, to be automatically formatted
     based on the configs.

-    The easiest way to use this class is by incorporating it into the
-    ``KedroContext``. This can be done by extending the ``KedroContext`` and overwriting
-    the config_loader method, making it return a ``TemplatedConfigLoader``
-    object instead of a ``ConfigLoader`` object.
-
-    For this method to work, the context_path variable in `.kedro.yml` (if exists) or
-    in `pyproject.toml` under `[tool.kedro]` section needs to be pointing at this newly
-    created class. The `run.py` script has an extension of the ``KedroContext`` by default,
-    called the ``ProjectContext``.
+    The easiest way to use this class is by registering it with the
+    ``KedroContext`` using hooks. This can be done by updating the
+    `register_config_loader` hook implementation in `hooks.py`, making it
+    return a ``TemplatedConfigLoader`` object instead of a ``ConfigLoader``
+    object.

     Example:
     ::

-        >>> from kedro.framework.context import KedroContext, load_context
         >>> from kedro.config import TemplatedConfigLoader
         >>>
         >>>
-        >>> class MyNewContext(KedroContext):
-        >>>
-        >>>     def _create_config_loader(self, conf_paths: Iterable[str]) -> TemplatedConfigLoader:
+        >>> class ProjectHooks:
+        >>>     @hook_impl
+        >>>     def register_config_loader(self, conf_paths: Iterable[str]) -> ConfigLoader:
         >>>         return TemplatedConfigLoader(
         >>>             conf_paths,
         >>>             globals_pattern="*globals.yml",
         >>>             globals_dict={"param1": "pandas.CSVDataSet"}
         >>>         )
-        >>>
-        >>> my_context = load_context(Path.cwd(), env=env)
-        >>> my_context.run(tags, runner, node_names, from_nodes, to_nodes)

     The contents of the dictionary resulting from the `globals_pattern` get
     merged with the ``globals_dict``. In case of conflicts, the keys in
diff --git a/kedro/framework/context/context.py b/kedro/framework/context/context.py
index 3981726498..28ea4554c7 100644
--- a/kedro/framework/context/context.py
+++ b/kedro/framework/context/context.py
@@ -252,10 +252,10 @@ def __init__(
         self.env = env or "local"
         self._extra_params = deepcopy(extra_params)

-        self._setup_logging()
-
         # setup hooks
         self._register_hooks(auto=True)
+        # we need a ConfigLoader registered in order to be able to set up logging
+        self._setup_logging()

     @property
     def static_data(self) -> Dict[str, Any]:
@@ -521,7 +521,11 @@ def _create_config_loader(  # pylint: disable=no-self-use
             Instance of `ConfigLoader`.

         """
-        return ConfigLoader(conf_paths)
+        hook_manager = get_hook_manager()
+        config_loader = hook_manager.hook.register_config_loader(  # pylint: disable=no-member
+            conf_paths=conf_paths
+        )
+        return config_loader or ConfigLoader(conf_paths)  # for backwards compatibility

     def _get_config_loader(self) -> ConfigLoader:
         """A hook for changing the creation of a ConfigLoader instance.
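
The fallback in `_create_config_loader` above works because `register_config_loader` is declared with `firstresult=True` further down in this patch: the hook call returns the first non-`None` loader an implementation provides, or `None` when nothing is registered, in which case a plain `ConfigLoader` is created. For illustration, a minimal sketch of a project-side `src/<python_package>/hooks.py` wired up this way; the `*globals.yml` pattern and the empty default pipeline are hypothetical examples, not part of this patch:

```python
# Sketch of a project-side hooks.py under the new registration mechanism.
from typing import Dict, Iterable

from kedro.config import ConfigLoader, TemplatedConfigLoader
from kedro.framework.hooks import hook_impl
from kedro.pipeline import Pipeline


class ProjectHooks:
    @hook_impl
    def register_config_loader(self, conf_paths: Iterable[str]) -> ConfigLoader:
        # The first non-None result wins (firstresult=True), so returning a
        # TemplatedConfigLoader here replaces the default ConfigLoader.
        return TemplatedConfigLoader(conf_paths, globals_pattern="*globals.yml")

    @hook_impl
    def register_pipelines(self) -> Dict[str, Pipeline]:
        # An empty default pipeline keeps the sketch self-contained.
        return {"__default__": Pipeline([])}


project_hooks = ProjectHooks()
```

With `project_hooks` listed under the `hooks` key of `.kedro.yml` (or `[tool.kedro]` in `pyproject.toml`), the `KedroContext` picks the loader up through the hook manager, as in the diff above.
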
diff --git a/kedro/framework/hooks/manager.py b/kedro/framework/hooks/manager.py index eb781d85c3..7a724149ce 100644 --- a/kedro/framework/hooks/manager.py +++ b/kedro/framework/hooks/manager.py @@ -32,7 +32,7 @@ from pluggy import PluginManager from .markers import HOOK_NAMESPACE -from .specs import DataCatalogSpecs, NodeSpecs, PipelineSpecs +from .specs import DataCatalogSpecs, NodeSpecs, PipelineSpecs, RegistrationSpecs _hook_manager = None @@ -44,6 +44,7 @@ def _create_hook_manager() -> PluginManager: manager.add_hookspecs(NodeSpecs) manager.add_hookspecs(PipelineSpecs) manager.add_hookspecs(DataCatalogSpecs) + manager.add_hookspecs(RegistrationSpecs) return manager diff --git a/kedro/framework/hooks/specs.py b/kedro/framework/hooks/specs.py index f44503ea70..b59599b25f 100644 --- a/kedro/framework/hooks/specs.py +++ b/kedro/framework/hooks/specs.py @@ -31,8 +31,9 @@ [Pluggy's documentation](https://pluggy.readthedocs.io/en/stable/#specs) """ # pylint: disable=too-many-arguments -from typing import Any, Dict +from typing import Any, Dict, Iterable +from kedro.config import ConfigLoader from kedro.io import DataCatalog from kedro.pipeline import Pipeline from kedro.pipeline.node import Node @@ -156,16 +157,6 @@ def on_node_error( class PipelineSpecs: """Namespace that defines all specifications for a pipeline's lifecycle hooks.""" - @hook_spec - def register_pipelines(self) -> Dict[str, Pipeline]: - """Hook to be invoked to register a project's pipelines. - - Returns: - A mapping from a pipeline name to a ``Pipeline`` object. - - """ - pass - @hook_spec def before_pipeline_run( self, run_params: Dict[str, Any], pipeline: Pipeline, catalog: DataCatalog @@ -261,3 +252,28 @@ def on_pipeline_error( catalog: The ``DataCatalog`` used during the run. """ pass + + +class RegistrationSpecs: + """Namespace that defines all specifications for hooks registering + library components with a Kedro project. + """ + + @hook_spec + def register_pipelines(self) -> Dict[str, Pipeline]: + """Hook to be invoked to register a project's pipelines. + + Returns: + A mapping from a pipeline name to a ``Pipeline`` object. + + """ + pass + + @hook_spec(firstresult=True) + def register_config_loader(self, conf_paths: Iterable[str]) -> ConfigLoader: + """Hook to be invoked to register a project's config loader. + + Returns: + An instance of a ``ConfigLoader``. + """ + pass diff --git a/kedro/templates/project/{{ cookiecutter.repo_name }}/src/{{ cookiecutter.python_package }}/hooks.py b/kedro/templates/project/{{ cookiecutter.repo_name }}/src/{{ cookiecutter.python_package }}/hooks.py index cba1e17350..88f66e0bf0 100644 --- a/kedro/templates/project/{{ cookiecutter.repo_name }}/src/{{ cookiecutter.python_package }}/hooks.py +++ b/kedro/templates/project/{{ cookiecutter.repo_name }}/src/{{ cookiecutter.python_package }}/hooks.py @@ -27,8 +27,9 @@ # limitations under the License. 
"""Project hooks.""" -from typing import Dict +from typing import Dict, Iterable +from kedro.config import ConfigLoader from kedro.framework.hooks import hook_impl from kedro.pipeline import Pipeline {%- if cookiecutter.include_example == "True" %} @@ -57,5 +58,9 @@ def register_pipelines(self) -> Dict[str, Pipeline]: } {%- else -%}return {"__default__": Pipeline([])}{%- endif %} + @hook_impl + def register_config_loader(self, conf_paths: Iterable[str]) -> ConfigLoader: + return ConfigLoader(conf_paths) + project_hooks = ProjectHooks() diff --git a/setup.cfg b/setup.cfg index f702990b17..d93b29e852 100644 --- a/setup.cfg +++ b/setup.cfg @@ -67,3 +67,4 @@ forbidden_modules = ignore_imports= kedro.runner.parallel_runner -> kedro.framework.context.context kedro.framework.context.context -> kedro.config + kedro.framework.hooks.specs -> kedro.config diff --git a/test_requirements.txt b/test_requirements.txt index 67724fdc94..ba3d7f2ece 100644 --- a/test_requirements.txt +++ b/test_requirements.txt @@ -31,6 +31,7 @@ pyarrow>=0.12.0, <1.0.0 pylint>=2.5.2, <3.0 pyspark~=2.2; python_version < '3.8' pytest-cov~=2.5 +pytest-lazy-fixture~=0.6.3 pytest-mock>=1.7.1,<2.0 pytest~=5.0 requests-mock~=1.6 diff --git a/tests/framework/context/test_context.py b/tests/framework/context/test_context.py index 96eb91a39b..915dc10e24 100644 --- a/tests/framework/context/test_context.py +++ b/tests/framework/context/test_context.py @@ -25,7 +25,7 @@ # # See the License for the specific language governing permissions and # limitations under the License. - +# pylint: disable=no-member import configparser import json import re @@ -40,19 +40,15 @@ from pandas.util.testing import assert_frame_equal from kedro import __version__ as kedro_version -from kedro.config import MissingConfigException +from kedro.config import ConfigLoader, MissingConfigException, TemplatedConfigLoader from kedro.extras.datasets.pandas import CSVDataSet -from kedro.framework.context import ( - KedroContext, - KedroContextError, - validate_source_path, -) +from kedro.framework.context import KedroContext, KedroContextError from kedro.framework.context.context import ( _convert_paths_to_absolute_posix, _is_relative_path, _validate_layers_for_transcoding, - get_static_project_data, ) +from kedro.framework.hooks import get_hook_manager, hook_impl from kedro.io.core import Version, generate_timestamp from kedro.pipeline import Pipeline, node from kedro.runner import ParallelRunner, SequentialRunner @@ -186,16 +182,6 @@ def bad_node(x): raise ValueError("Oh no!") -bad_pipeline_middle = Pipeline( - [ - node(identity, "cars", "boats", name="node1", tags=["tag1"]), - node(identity, "boats", "trains", name="node2"), - node(bad_node, "trains", "ships", name="nodes3"), - node(identity, "ships", "planes", name="node4"), - ], - tags="bad_pipeline", -) - expected_message_middle = ( "There are 2 nodes that have not run.\n" "You can resume the pipeline run by adding the following " @@ -204,16 +190,6 @@ def bad_node(x): ) -bad_pipeline_head = Pipeline( - [ - node(bad_node, "cars", "boats", name="node1", tags=["tag1"]), - node(identity, "boats", "trains", name="node2"), - node(identity, "trains", "ships", name="nodes3"), - node(identity, "ships", "planes", name="node4"), - ], - tags="bad_pipeline", -) - expected_message_head = ( "There are 4 nodes that have not run.\n" "You can resume the pipeline run by adding the following " @@ -227,18 +203,65 @@ def bad_node(x): } +def _create_pipelines(): + bad_pipeline_middle = Pipeline( + [ + node(identity, "cars", "boats", 
name="node1", tags=["tag1"]), + node(identity, "boats", "trains", name="node2"), + node(bad_node, "trains", "ships", name="nodes3"), + node(identity, "ships", "planes", name="node4"), + ], + tags="bad_pipeline", + ) + bad_pipeline_head = Pipeline( + [ + node(bad_node, "cars", "boats", name="node1", tags=["tag1"]), + node(identity, "boats", "trains", name="node2"), + node(identity, "trains", "ships", name="nodes3"), + node(identity, "ships", "planes", name="node4"), + ], + tags="bad_pipeline", + ) + default_pipeline = Pipeline( + [ + node(identity, "cars", "boats", name="node1", tags=["tag1"]), + node(identity, "boats", "trains", name="node2"), + node(identity, "trains", "ships", name="node3"), + node(identity, "ships", "planes", name="node4"), + ], + tags="pipeline", + ) + return { + "__default__": default_pipeline, + "empty": Pipeline([]), + "simple": Pipeline([node(identity, "cars", "boats")]), + "bad_pipeline_middle": bad_pipeline_middle, + "bad_pipeline_head": bad_pipeline_head, + } + + +class RegistrationHooks: + @hook_impl + def register_config_loader(self, conf_paths) -> ConfigLoader: + return ConfigLoader(conf_paths) + + @hook_impl + def register_pipelines(self) -> Dict[str, Pipeline]: + return _create_pipelines() + + class DummyContext(KedroContext): + hooks = (RegistrationHooks(),) + + +class DummyContextNoHooks(KedroContext): + def _create_config_loader( # pylint: disable=no-self-use + self, conf_paths + ) -> ConfigLoader: + return TemplatedConfigLoader(conf_paths) + def _get_pipelines(self) -> Dict[str, Pipeline]: - pipeline = Pipeline( - [ - node(identity, "cars", "boats", name="node1", tags=["tag1"]), - node(identity, "boats", "trains", name="node2"), - node(identity, "trains", "ships", name="node3"), - node(identity, "ships", "planes", name="node4"), - ], - tags="pipeline", - ) - return {"__default__": pipeline} + return _create_pipelines() @pytest.fixture(params=[None]) @@ -254,12 +277,39 @@ def mocked_logging(mocker): @pytest.fixture -def dummy_context( +def dummy_context_with_hooks( tmp_path, prepare_project_dir, env, extra_params ): # pylint: disable=unused-argument return DummyContext(str(tmp_path), env=env, extra_params=extra_params) +@pytest.fixture +def dummy_context_no_hooks( + tmp_path, prepare_project_dir, env, extra_params +): # pylint: disable=unused-argument + return DummyContextNoHooks(str(tmp_path), env=env, extra_params=extra_params) + + +@pytest.fixture( + params=[ + pytest.lazy_fixture("dummy_context_with_hooks"), + pytest.lazy_fixture("dummy_context_no_hooks"), + ] +) +def dummy_context(request): + # for backwards-compatibility, test with and without registration hooks + return request.param + + +@pytest.fixture(autouse=True) +def clear_hook_manager(): + yield + hook_manager = get_hook_manager() + plugins = hook_manager.get_plugins() + for plugin in plugins: + hook_manager.unregister(plugin) + + class TestKedroContext: def test_attributes(self, tmp_path, dummy_context): assert dummy_context.project_name == kedro_yml_payload["project_name"] @@ -326,14 +376,21 @@ def test_nested_params(self, param, expected, dummy_context): [None, {}, {"foo": "bar", "baz": [1, 2], "qux": None}], indirect=True, ) - def test_params_missing(self, dummy_context, mocker, extra_params): - mock_config_loader = mocker.patch.object(DummyContext, "config_loader") + @pytest.mark.parametrize( + "context_class, context_fixture", + [ + (DummyContext, pytest.lazy_fixture("dummy_context_with_hooks")), + (DummyContextNoHooks, pytest.lazy_fixture("dummy_context_no_hooks")), + ], + ) + def 
test_params_missing(self, mocker, extra_params, context_class, context_fixture): + mock_config_loader = mocker.patch.object(context_class, "config_loader") mock_config_loader.get.side_effect = MissingConfigException("nope") extra_params = extra_params or {} pattern = "Parameters not found in your Kedro project config" with pytest.warns(UserWarning, match=pattern): - actual = dummy_context.params + actual = context_fixture.params assert actual == extra_params def test_config_loader(self, dummy_context): @@ -355,41 +412,60 @@ def test_default_env(self, dummy_context): @pytest.mark.parametrize( "invalid_version", ["0.13.0", "10.0", "101.1", "100.0", "-0"] ) - def test_invalid_version(self, mocker, invalid_version, dummy_context): - mocker.patch.object(DummyContext, "project_version", invalid_version) + @pytest.mark.parametrize( + "context_class,context_fixture", + [ + (DummyContext, pytest.lazy_fixture("dummy_context_with_hooks")), + (DummyContextNoHooks, pytest.lazy_fixture("dummy_context_no_hooks")), + ], + ) + def test_invalid_version( + self, mocker, invalid_version, context_class, context_fixture + ): + mocker.patch.object(context_class, "project_version", invalid_version) pattern = ( f"Your Kedro project version {invalid_version} does not match " f"Kedro package version {kedro_version} you are running." ) with pytest.raises(KedroContextError, match=re.escape(pattern)): - DummyContext(dummy_context.project_path) + context_class(context_fixture.project_path) @pytest.mark.parametrize("env", ["custom_env"], indirect=True) def test_custom_env(self, dummy_context, env): assert dummy_context.env == env - def test_missing_parameters(self, tmp_path, dummy_context): + @pytest.mark.parametrize( + "context_class,context_fixture", + [ + (DummyContext, pytest.lazy_fixture("dummy_context_with_hooks")), + (DummyContextNoHooks, pytest.lazy_fixture("dummy_context_no_hooks")), + ], + ) + def test_missing_parameters(self, tmp_path, context_class, context_fixture): parameters = tmp_path / "conf" / "base" / "parameters.json" parameters.unlink() pattern = "Parameters not found in your Kedro project config." with pytest.warns(UserWarning, match=re.escape(pattern)): - DummyContext( # pylint: disable=expression-not-assigned - dummy_context.project_path - ).catalog + _ = context_class(context_fixture.project_path).catalog - def test_missing_credentials(self, dummy_context): + @pytest.mark.parametrize( + "context_class,context_fixture", + [ + (DummyContext, pytest.lazy_fixture("dummy_context_with_hooks")), + (DummyContextNoHooks, pytest.lazy_fixture("dummy_context_no_hooks")), + ], + ) + def test_missing_credentials(self, context_class, context_fixture): env_credentials = ( - dummy_context.project_path / "conf" / "local" / "credentials.yml" + context_fixture.project_path / "conf" / "local" / "credentials.yml" ) env_credentials.unlink() pattern = "Credentials not found in your Kedro project config." 
with pytest.warns(UserWarning, match=re.escape(pattern)): - DummyContext( # pylint: disable=expression-not-assigned - dummy_context.project_path - ).catalog + _ = context_class(context_fixture.project_path).catalog def test_pipeline(self, dummy_context): assert dummy_context.pipeline.nodes[0].inputs == ["cars"] @@ -398,7 +474,7 @@ def test_pipeline(self, dummy_context): assert dummy_context.pipeline.nodes[1].outputs == ["trains"] def test_pipelines(self, dummy_context): - assert len(dummy_context.pipelines) == 1 + assert len(dummy_context.pipelines) == 5 assert len(dummy_context.pipelines["__default__"].nodes) == 4 def test_setup_logging_using_absolute_path(self, dummy_context, mocked_logging): @@ -538,14 +614,9 @@ def test_run_from_inputs(self, dummy_context, dummy_dataframe, caplog): assert "Pipeline execution completed successfully." in log_msgs @pytest.mark.usefixtures("prepare_project_dir") - def test_run_load_versions(self, tmp_path, dummy_dataframe, mocker): - mocker.patch.object( - DummyContext, - "_get_pipelines", - return_value={"__default__": Pipeline([node(identity, "cars", "boats")])}, - ) - - context = DummyContext(tmp_path) + @pytest.mark.parametrize("context_class", [DummyContext, DummyContextNoHooks]) + def test_run_load_versions(self, tmp_path, dummy_dataframe, context_class): + context = context_class(tmp_path) filepath = (context.project_path / "cars.csv").as_posix() old_save_version = generate_timestamp() @@ -567,50 +638,42 @@ def test_run_load_versions(self, tmp_path, dummy_dataframe, mocker): new_csv_data_set.save(dummy_dataframe) load_versions = {"cars": old_save_version} - context.run(load_versions=load_versions) + context.run(load_versions=load_versions, pipeline_name="simple") assert not context.catalog.load("boats").equals(dummy_dataframe) assert context.catalog.load("boats").equals(old_df) @pytest.mark.usefixtures("prepare_project_dir") - def test_run_with_empty_pipeline(self, tmp_path, mocker): - mocker.patch.object( - DummyContext, "_get_pipelines", return_value={"__default__": Pipeline([])}, - ) - - context = DummyContext(tmp_path) + @pytest.mark.parametrize("context_class", [DummyContext, DummyContextNoHooks]) + def test_run_with_empty_pipeline(self, tmp_path, context_class): + context = context_class(tmp_path) assert context.project_name == kedro_yml_payload["project_name"] assert context.project_version == kedro_yml_payload["project_version"] with pytest.raises(KedroContextError, match="Pipeline contains no nodes"): - context.run() + context.run(pipeline_name="empty") @pytest.mark.parametrize( - "context_pipeline,expected_message", + "pipeline_name,expected_message", [ - (bad_pipeline_middle, expected_message_middle), - (bad_pipeline_head, expected_message_head), + ("bad_pipeline_middle", expected_message_middle), + ("bad_pipeline_head", expected_message_head), ], # pylint: disable=too-many-arguments ) + @pytest.mark.parametrize("context_class", [DummyContext, DummyContextNoHooks]) @pytest.mark.usefixtures("prepare_project_dir") def test_run_failure_prompts_resume_command( self, - mocker, tmp_path, dummy_dataframe, caplog, - context_pipeline, + pipeline_name, expected_message, + context_class, ): - mocker.patch.object( - DummyContext, - "_get_pipelines", - return_value={"__default__": context_pipeline}, - ) - - bad_context = DummyContext(tmp_path) + bad_context = context_class(tmp_path) bad_context.catalog.save("cars", dummy_dataframe) with pytest.raises(ValueError, match="Oh no"): - bad_context.run() + bad_context.run(pipeline_name=pipeline_name) 
actual_messages = [ record.getMessage() @@ -641,18 +704,19 @@ def test_run_with_extra_params( assert mock_journal.call_args[0][0]["extra_params"] == extra_params @pytest.mark.usefixtures("prepare_project_dir") + @pytest.mark.parametrize("context_class", [DummyContext, DummyContextNoHooks]) def test_run_with_save_version_as_run_id( - self, mocker, tmp_path, dummy_dataframe, caplog + self, mocker, tmp_path, dummy_dataframe, caplog, context_class ): """Test that the default behaviour, with run_id set to None, creates a journal record with the run_id the same as save_version. """ save_version = "2020-01-01T00.00.00.000Z" mocked_get_save_version = mocker.patch.object( - DummyContext, "_get_save_version", return_value=save_version + context_class, "_get_save_version", return_value=save_version ) - context = DummyContext(tmp_path) + context = context_class(tmp_path) context.catalog.save("cars", dummy_dataframe) context.run(load_versions={"boats": save_version}) @@ -665,13 +729,16 @@ def test_run_with_save_version_as_run_id( assert json.loads(log_msg)["run_id"] == save_version @pytest.mark.usefixtures("prepare_project_dir") - def test_run_with_custom_run_id(self, mocker, tmp_path, dummy_dataframe, caplog): + @pytest.mark.parametrize("context_class", [DummyContext, DummyContextNoHooks]) + def test_run_with_custom_run_id( + self, mocker, tmp_path, dummy_dataframe, caplog, context_class + ): run_id = "001" mocked_get_run_id = mocker.patch.object( - DummyContext, "_get_run_id", return_value=run_id + context_class, "_get_run_id", return_value=run_id ) - context = DummyContext(tmp_path) + context = context_class(tmp_path) context.catalog.save("cars", dummy_dataframe) context.run() @@ -828,119 +895,3 @@ def test_validate_layers_error(layers, conflicting_datasets, mocker): ) with pytest.raises(ValueError, match=re.escape(pattern)): _validate_layers_for_transcoding(mock_catalog) - - -class TestValidateSourcePath: - @pytest.mark.parametrize( - "source_dir", [".", "src", "./src", "src/nested", "src/nested/nested"] - ) - def test_valid_source_path(self, tmp_path, source_dir): - source_path = (tmp_path / source_dir).resolve() - source_path.mkdir(parents=True, exist_ok=True) - validate_source_path(source_path, tmp_path.resolve()) - - @pytest.mark.parametrize("source_dir", ["..", "src/../..", "~"]) - def test_invalid_source_path(self, tmp_path, source_dir): - source_dir = Path(source_dir).expanduser() - source_path = (tmp_path / source_dir).resolve() - source_path.mkdir(parents=True, exist_ok=True) - - pattern = re.escape( - f"Source path '{source_path}' has to be relative to your project root " - f"'{tmp_path.resolve()}'" - ) - with pytest.raises(KedroContextError, match=pattern): - validate_source_path(source_path, tmp_path.resolve()) - - def test_non_existent_source_path(self, tmp_path): - source_path = (tmp_path / "non_existent").resolve() - - pattern = re.escape(f"Source path '{source_path}' cannot be found.") - with pytest.raises(KedroContextError, match=pattern): - validate_source_path(source_path, tmp_path.resolve()) - - -class TestGetStaticProjectData: - project_path = Path.cwd() - - def test_no_config_files(self, mocker): - mocker.patch.object(Path, "is_file", return_value=False) - - pattern = ( - f"Could not find any of configuration files '.kedro.yml, pyproject.toml' " - f"in {self.project_path}" - ) - with pytest.raises(KedroContextError, match=re.escape(pattern)): - get_static_project_data(self.project_path) - - def test_kedro_yml_invalid_format(self, tmp_path): - """Test for loading context from 
an invalid path. """ - kedro_yml_path = tmp_path / ".kedro.yml" - kedro_yml_path.write_text("!!") # Invalid YAML - pattern = "Failed to parse '.kedro.yml' file" - with pytest.raises(KedroContextError, match=re.escape(pattern)): - get_static_project_data(str(tmp_path)) - - def test_toml_invalid_format(self, tmp_path): - """Test for loading context from an invalid path. """ - toml_path = tmp_path / "pyproject.toml" - toml_path.write_text("!!") # Invalid TOML - pattern = "Failed to parse 'pyproject.toml' file" - with pytest.raises(KedroContextError, match=re.escape(pattern)): - get_static_project_data(str(tmp_path)) - - def test_valid_yml_file_exists(self, mocker): - # Both yml and toml files exist - mocker.patch.object(Path, "is_file", return_value=True) - mocker.patch("anyconfig.load", return_value={}) - - static_data = get_static_project_data(self.project_path) - - # Using default source directory - assert static_data == { - "source_dir": self.project_path / "src", - "config_file": self.project_path / ".kedro.yml", - } - - def test_valid_toml_file(self, mocker): - # .kedro.yml doesn't exists - mocker.patch.object(Path, "is_file", side_effect=[False, True]) - mocker.patch("anyconfig.load", return_value={"tool": {"kedro": {}}}) - - static_data = get_static_project_data(self.project_path) - - # Using default source directory - assert static_data == { - "source_dir": self.project_path / "src", - "config_file": self.project_path / "pyproject.toml", - } - - def test_toml_file_without_kedro_section(self, mocker): - mocker.patch.object(Path, "is_file", side_effect=[False, True]) - mocker.patch("anyconfig.load", return_value={}) - - pattern = "There's no '[tool.kedro]' section in the 'pyproject.toml'." - - with pytest.raises(KedroContextError, match=re.escape(pattern)): - get_static_project_data(self.project_path) - - def test_source_dir_specified_in_yml(self, mocker): - mocker.patch.object(Path, "is_file", side_effect=[True, False]) - source_dir = "test_dir" - mocker.patch("anyconfig.load", return_value={"source_dir": source_dir}) - - static_data = get_static_project_data(self.project_path) - - assert static_data["source_dir"] == self.project_path / source_dir - - def test_source_dir_specified_in_toml(self, mocker): - mocker.patch.object(Path, "is_file", side_effect=[False, True]) - source_dir = "test_dir" - mocker.patch( - "anyconfig.load", - return_value={"tool": {"kedro": {"source_dir": source_dir}}}, - ) - - static_data = get_static_project_data(self.project_path) - - assert static_data["source_dir"] == self.project_path / source_dir diff --git a/tests/framework/context/test_static_data.py b/tests/framework/context/test_static_data.py new file mode 100644 index 0000000000..ba663b62de --- /dev/null +++ b/tests/framework/context/test_static_data.py @@ -0,0 +1,153 @@ +# Copyright 2020 QuantumBlack Visual Analytics Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, +# EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES +# OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE, AND +# NONINFRINGEMENT. 
IN NO EVENT WILL THE LICENSOR OR OTHER CONTRIBUTORS
+# BE LIABLE FOR ANY CLAIM, DAMAGES, OR OTHER LIABILITY, WHETHER IN AN
+# ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF, OR IN
+# CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
+#
+# The QuantumBlack Visual Analytics Limited ("QuantumBlack") name and logo
+# (either separately or in combination, "QuantumBlack Trademarks") are
+# trademarks of QuantumBlack. The License does not grant you any right or
+# license to the QuantumBlack Trademarks. You may not use the QuantumBlack
+# Trademarks or any confusingly similar mark as a trademark for your product,
+# or use the QuantumBlack Trademarks in any other manner that might cause
+# confusion in the marketplace, including but not limited to in advertising,
+# on websites, or on software.
+#
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import re
+from pathlib import Path
+
+import pytest
+
+from kedro.framework.context import (
+    KedroContextError,
+    get_static_project_data,
+    validate_source_path,
+)
+
+
+class TestValidateSourcePath:
+    @pytest.mark.parametrize(
+        "source_dir", [".", "src", "./src", "src/nested", "src/nested/nested"]
+    )
+    def test_valid_source_path(self, tmp_path, source_dir):
+        source_path = (tmp_path / source_dir).resolve()
+        source_path.mkdir(parents=True, exist_ok=True)
+        validate_source_path(source_path, tmp_path.resolve())
+
+    @pytest.mark.parametrize("source_dir", ["..", "src/../..", "~"])
+    def test_invalid_source_path(self, tmp_path, source_dir):
+        source_dir = Path(source_dir).expanduser()
+        source_path = (tmp_path / source_dir).resolve()
+        source_path.mkdir(parents=True, exist_ok=True)
+
+        pattern = re.escape(
+            f"Source path '{source_path}' has to be relative to your project root "
+            f"'{tmp_path.resolve()}'"
+        )
+        with pytest.raises(KedroContextError, match=pattern):
+            validate_source_path(source_path, tmp_path.resolve())
+
+    def test_non_existent_source_path(self, tmp_path):
+        source_path = (tmp_path / "non_existent").resolve()
+
+        pattern = re.escape(f"Source path '{source_path}' cannot be found.")
+        with pytest.raises(KedroContextError, match=pattern):
+            validate_source_path(source_path, tmp_path.resolve())
+
+
+class TestGetStaticProjectData:
+    project_path = Path.cwd()
+
+    def test_no_config_files(self, mocker):
+        mocker.patch.object(Path, "is_file", return_value=False)
+
+        pattern = (
+            f"Could not find any of configuration files '.kedro.yml, pyproject.toml' "
+            f"in {self.project_path}"
+        )
+        with pytest.raises(KedroContextError, match=re.escape(pattern)):
+            get_static_project_data(self.project_path)
+
+    def test_kedro_yml_invalid_format(self, tmp_path):
+        """Test loading static project data from a `.kedro.yml` file with invalid YAML."""
+        kedro_yml_path = tmp_path / ".kedro.yml"
+        kedro_yml_path.write_text("!!")  # Invalid YAML
+        pattern = "Failed to parse '.kedro.yml' file"
+        with pytest.raises(KedroContextError, match=re.escape(pattern)):
+            get_static_project_data(str(tmp_path))
+
+    def test_toml_invalid_format(self, tmp_path):
+        """Test loading static project data from a `pyproject.toml` file with invalid TOML."""
""" + toml_path = tmp_path / "pyproject.toml" + toml_path.write_text("!!") # Invalid TOML + pattern = "Failed to parse 'pyproject.toml' file" + with pytest.raises(KedroContextError, match=re.escape(pattern)): + get_static_project_data(str(tmp_path)) + + def test_valid_yml_file_exists(self, mocker): + # Both yml and toml files exist + mocker.patch.object(Path, "is_file", return_value=True) + mocker.patch("anyconfig.load", return_value={}) + + static_data = get_static_project_data(self.project_path) + + # Using default source directory + assert static_data == { + "source_dir": self.project_path / "src", + "config_file": self.project_path / ".kedro.yml", + } + + def test_valid_toml_file(self, mocker): + # .kedro.yml doesn't exists + mocker.patch.object(Path, "is_file", side_effect=[False, True]) + mocker.patch("anyconfig.load", return_value={"tool": {"kedro": {}}}) + + static_data = get_static_project_data(self.project_path) + + # Using default source directory + assert static_data == { + "source_dir": self.project_path / "src", + "config_file": self.project_path / "pyproject.toml", + } + + def test_toml_file_without_kedro_section(self, mocker): + mocker.patch.object(Path, "is_file", side_effect=[False, True]) + mocker.patch("anyconfig.load", return_value={}) + + pattern = "There's no '[tool.kedro]' section in the 'pyproject.toml'." + + with pytest.raises(KedroContextError, match=re.escape(pattern)): + get_static_project_data(self.project_path) + + def test_source_dir_specified_in_yml(self, mocker): + mocker.patch.object(Path, "is_file", side_effect=[True, False]) + source_dir = "test_dir" + mocker.patch("anyconfig.load", return_value={"source_dir": source_dir}) + + static_data = get_static_project_data(self.project_path) + + assert static_data["source_dir"] == self.project_path / source_dir + + def test_source_dir_specified_in_toml(self, mocker): + mocker.patch.object(Path, "is_file", side_effect=[False, True]) + source_dir = "test_dir" + mocker.patch( + "anyconfig.load", + return_value={"tool": {"kedro": {"source_dir": source_dir}}}, + ) + + static_data = get_static_project_data(self.project_path) + + assert static_data["source_dir"] == self.project_path / source_dir diff --git a/tests/framework/hooks/test_context_hooks.py b/tests/framework/hooks/test_context_hooks.py index fe3568fbd6..3250d0d455 100644 --- a/tests/framework/hooks/test_context_hooks.py +++ b/tests/framework/hooks/test_context_hooks.py @@ -31,13 +31,14 @@ from logging.handlers import QueueHandler, QueueListener from multiprocessing import Queue from pathlib import Path -from typing import Any, Dict, List, Union +from typing import Any, Dict, Iterable, List, Union import pandas as pd import pytest import yaml from kedro import __version__ +from kedro.config import ConfigLoader from kedro.framework.context import KedroContext from kedro.framework.context.context import _convert_paths_to_absolute_posix from kedro.framework.hooks import hook_impl @@ -289,6 +290,11 @@ def register_pipelines(self) -> Dict[str, Pipeline]: self.logger.info("Registering pipelines") return {"__default__": context_pipeline, "de": context_pipeline} + @hook_impl + def register_config_loader(self, conf_paths: Iterable[str]) -> ConfigLoader: + self.logger.info("Registering config loader", extra={"conf_paths": conf_paths}) + return ConfigLoader(conf_paths) + class RegistrationHooks: @hook_impl @@ -376,15 +382,15 @@ def _get_pipelines(self) -> Dict[str, Pipeline]: return BrokenContextWithHooks(tmp_path, env="local") -class TestKedroContextHooks: - 
@staticmethod - def _assert_hook_call_record_has_expected_parameters( - call_record: logging.LogRecord, expected_parameters: List[str] - ): - """Assert the given call record has all expected parameters.""" - for param in expected_parameters: - assert hasattr(call_record, param) +def _assert_hook_call_record_has_expected_parameters( + call_record: logging.LogRecord, expected_parameters: List[str] +): + """Assert the given call record has all expected parameters.""" + for param in expected_parameters: + assert hasattr(call_record, param) + +class TestKedroContextHooks: def test_calling_register_hooks_multiple_times_should_not_raise( self, context_with_hooks ): @@ -417,10 +423,13 @@ def test_after_catalog_created_hook_is_called(self, context_with_hooks, caplog): catalog = context_with_hooks.catalog config_loader = context_with_hooks.config_loader relevant_records = [ - r for r in caplog.records if r.name == LoggingHooks.handler_name + r + for r in caplog.records + if r.name == LoggingHooks.handler_name + and r.getMessage() == "Catalog created" ] + assert len(relevant_records) == 1 record = relevant_records[0] - assert record.getMessage() == "Catalog created" assert record.catalog == catalog assert record.conf_creds == config_loader.get("credentials*") assert record.conf_catalog == _convert_paths_to_absolute_posix( @@ -448,7 +457,7 @@ def test_before_and_after_pipeline_run_hooks_are_called( assert len(before_pipeline_run_calls) == 1 call_record = before_pipeline_run_calls[0] assert call_record.pipeline.describe() == context_with_hooks.pipeline.describe() - self._assert_hook_call_record_has_expected_parameters( + _assert_hook_call_record_has_expected_parameters( call_record, ["pipeline", "catalog", "run_params"] ) @@ -460,7 +469,7 @@ def test_before_and_after_pipeline_run_hooks_are_called( ] assert len(after_pipeline_run_calls) == 1 call_record = after_pipeline_run_calls[0] - self._assert_hook_call_record_has_expected_parameters( + _assert_hook_call_record_has_expected_parameters( call_record, ["pipeline", "catalog", "run_params"] ) assert call_record.pipeline.describe() == context_with_hooks.pipeline.describe() @@ -476,7 +485,7 @@ def test_on_pipeline_error_hook_is_called(self, broken_context_with_hooks, caplo ] assert len(on_pipeline_error_calls) == 1 call_record = on_pipeline_error_calls[0] - self._assert_hook_call_record_has_expected_parameters( + _assert_hook_call_record_has_expected_parameters( call_record, ["error", "run_params", "pipeline", "catalog"] ) expected_error = ValueError("broken") @@ -493,7 +502,7 @@ def test_on_node_error_hook_is_called_with_sequential_runner( ] assert len(on_node_error_calls) == 1 call_record = on_node_error_calls[0] - self._assert_hook_call_record_has_expected_parameters( + _assert_hook_call_record_has_expected_parameters( call_record, ["error", "node", "catalog", "inputs", "is_async", "run_id"] ) expected_error = ValueError("broken") @@ -511,7 +520,7 @@ def test_before_and_after_node_run_hooks_are_called_with_sequential_runner( ] assert len(before_node_run_calls) == 1 call_record = before_node_run_calls[0] - self._assert_hook_call_record_has_expected_parameters( + _assert_hook_call_record_has_expected_parameters( call_record, ["node", "catalog", "inputs", "is_async", "run_id"] ) # sanity check a couple of important parameters @@ -524,7 +533,7 @@ def test_before_and_after_node_run_hooks_are_called_with_sequential_runner( ] assert len(after_node_run_calls) == 1 call_record = after_node_run_calls[0] - self._assert_hook_call_record_has_expected_parameters( + 
_assert_hook_call_record_has_expected_parameters( call_record, ["node", "catalog", "inputs", "outputs", "is_async", "run_id"] ) # sanity check a couple of important parameters @@ -565,7 +574,7 @@ def handle(self, record): assert len(on_node_error_records) == 2 for call_record in on_node_error_records: - self._assert_hook_call_record_has_expected_parameters( + _assert_hook_call_record_has_expected_parameters( call_record, ["error", "node", "catalog", "inputs", "is_async", "run_id"], ) @@ -616,6 +625,8 @@ def handle(self, record): assert record.node.name in ["node1", "node2"] assert set(record.outputs.keys()) <= {"planes", "ships"} + +class TestRegistrationHooks: def test_register_pipelines_is_called( self, context_with_hooks, dummy_dataframe, caplog ): @@ -630,7 +641,7 @@ def test_register_pipelines_is_called( ] assert len(register_pipelines_calls) == 1 call_record = register_pipelines_calls[0] - self._assert_hook_call_record_has_expected_parameters(call_record, []) + _assert_hook_call_record_has_expected_parameters(call_record, []) expected_pipelines = {"__default__": context_pipeline, "de": context_pipeline} assert context_with_hooks.pipelines == expected_pipelines @@ -650,3 +661,22 @@ def test_register_pipelines_with_duplicate_entries( key: context_pipeline for key in ("__default__", "de", "pipe") } assert context_with_duplicate_hooks.pipelines == expected_pipelines + + def test_register_config_loader_is_called(self, context_with_hooks, caplog): + _ = context_with_hooks.config_loader + + relevant_records = [ + r for r in caplog.records if r.name == LoggingHooks.handler_name + ] + assert len(relevant_records) == 1 + record = relevant_records[0] + assert record.getMessage() == "Registering config loader" + expected_conf_paths = [ + str( + context_with_hooks.project_path / context_with_hooks.CONF_ROOT / "base" + ), + str( + context_with_hooks.project_path / context_with_hooks.CONF_ROOT / "local" + ), + ] + assert record.conf_paths == expected_conf_paths
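
As a closing note, one way to smoke-test the registration end-to-end outside the test suite, assuming the `hooks.py` sketch earlier and a hypothetical project whose `.kedro.yml` lists `my_project.hooks.project_hooks` under its `hooks` key:

```python
# Hypothetical check that the hook-registered loader is the one the context
# uses; run from a project root whose .kedro.yml contains, e.g.:
#
#   hooks:
#     - my_project.hooks.project_hooks
#
from kedro.config import TemplatedConfigLoader
from kedro.framework.context import load_context

context = load_context(".")  # "." is the project root here
# _create_config_loader resolves the loader via the register_config_loader
# hook and only falls back to a plain ConfigLoader when no hook returns one.
assert isinstance(context.config_loader, TemplatedConfigLoader)
print(context.config_loader.get("catalog*", "catalog*/**"))
```
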