Skip to content

Commit

Permalink
Handle case where create_pipeline is not exposed
Browse files Browse the repository at this point in the history
  • Loading branch information
deepyaman committed Jul 13, 2022
1 parent 2e9ae0e commit bbe9763
Show file tree
Hide file tree
Showing 2 changed files with 38 additions and 3 deletions.
14 changes: 12 additions & 2 deletions kedro/framework/project/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import logging.config
import operator
import sys
import warnings
from collections import UserDict
from collections.abc import MutableMapping
from pathlib import Path
Expand Down Expand Up @@ -272,6 +273,15 @@ def find_pipelines() -> Dict[str, Pipeline]:
"""
pipelines_dict = {"__default__": pipeline([])}
for pipeline_name in importlib.resources.contents(f"{PACKAGE_NAME}.pipelines"):
importlib.import_module(f"{PACKAGE_NAME}.pipelines", pipeline_name)
pipelines_dict[pipeline_name] = pipeline([])
pipeline_module = importlib.import_module(
f"{PACKAGE_NAME}.pipelines.{pipeline_name}"
)
if not hasattr(pipeline_module, "create_pipeline"):
warnings.warn(
f"The '{pipeline_module.__name__}' module does not "
f"expose a 'create_pipeline' function, so no pipelines "
f"defined therein will be returned by 'find_pipelines'."
)
continue
pipelines_dict[pipeline_name] = getattr(pipeline_module, "create_pipeline")()
return pipelines_dict
27 changes: 26 additions & 1 deletion tests/framework/project/test_pipeline_discovery.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import sys
import textwrap
from pathlib import Path

import pytest

Expand Down Expand Up @@ -40,11 +41,35 @@ def pipeline_names(request):
[(x, x) for x in [set(), {"my_pipeline"}]],
indirect=True,
)
def test_pipelines_without_configure_project_is_empty(
def test_find_pipelines(
mock_package_name_with_pipelines, # pylint: disable=unused-argument
pipeline_names,
):
configure_project(mock_package_name_with_pipelines)
pipelines = find_pipelines()
assert set(pipelines) == pipeline_names | {"__default__"}
assert sum(pipelines.values()).outputs() == pipeline_names


@pytest.mark.parametrize(
"mock_package_name_with_pipelines,pipeline_names",
[(x, x) for x in [set(), {"good_pipeline"}]],
indirect=True,
)
def test_find_pipelines_skips_modules_without_create_pipelines_function(
mock_package_name_with_pipelines, # pylint: disable=unused-argument
pipeline_names,
):
# Create a module without `create_pipelines` in the `pipelines` dir.
pipelines_dir = Path(sys.path[0]) / mock_package_name_with_pipelines / "pipelines"
pipeline_dir = pipelines_dir / "bad_pipeline"
pipeline_dir.mkdir()
(pipeline_dir / "__init__.py").touch()

configure_project(mock_package_name_with_pipelines)
with pytest.warns(
UserWarning, match="module does not expose a 'create_pipeline' function"
):
pipelines = find_pipelines()
assert set(pipelines) == pipeline_names | {"__default__"}
assert sum(pipelines.values()).outputs() == pipeline_names

0 comments on commit bbe9763

Please sign in to comment.