Skip to content

Commit

Permalink
Trawl the pipelines directory and return as keys
Browse files Browse the repository at this point in the history
Signed-off-by: Deepyaman Datta <deepyaman.datta@utexas.edu>
  • Loading branch information
deepyaman committed Jul 13, 2022
1 parent 3f23527 commit 5485c40
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 8 deletions.
7 changes: 6 additions & 1 deletion kedro/framework/project/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
configure a Kedro project and access its settings."""
# pylint: disable=redefined-outer-name,unused-argument,global-statement
import importlib
import importlib.resources
import logging.config
import operator
import sys
Expand Down Expand Up @@ -269,4 +270,8 @@ def find_pipelines() -> Dict[str, Pipeline]:
function, the ``create_pipeline`` function does not return a
``Pipeline`` object, or if the module import fails up front.
"""
return {"__default__": pipeline([])}
pipelines = {"__default__": pipeline([])}
for pipeline_name in importlib.resources.contents(f"{PACKAGE_NAME}.pipelines"):
importlib.import_module(f"{PACKAGE_NAME}.pipelines", pipeline_name)
pipelines[pipeline_name] = pipeline([])
return pipelines
16 changes: 9 additions & 7 deletions tests/framework/project/test_pipeline_discovery.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,13 @@

import pytest

from kedro.framework.project import find_pipelines
from kedro.framework.project import configure_project, find_pipelines


@pytest.fixture
def mock_package_with_pipelines(tmp_path, request):
pipelines_dir = tmp_path / "test_package" / "pipelines"
def mock_package_name_with_pipelines(tmp_path, request):
package_name = "test_package"
pipelines_dir = tmp_path / package_name / "pipelines"
pipelines_dir.mkdir(parents=True)
for pipeline_name in request.param:
pipeline_dir = pipelines_dir / pipeline_name
Expand All @@ -25,7 +26,7 @@ def create_pipeline(**kwargs) -> Pipeline:
)
)
sys.path.insert(0, str(tmp_path))
yield
yield package_name
sys.path.pop(0)


Expand All @@ -35,14 +36,15 @@ def pipeline_names(request):


@pytest.mark.parametrize(
"mock_package_with_pipelines,pipeline_names",
[(x, x) for x in [set()]],
"mock_package_name_with_pipelines,pipeline_names",
[(x, x) for x in [set(), {"my_pipeline"}]],
indirect=True,
)
def test_pipelines_without_configure_project_is_empty(
mock_package_with_pipelines, # pylint: disable=unused-argument
mock_package_name_with_pipelines, # pylint: disable=unused-argument
pipeline_names,
):
configure_project(mock_package_name_with_pipelines)
pipelines = find_pipelines()
assert set(pipelines) == pipeline_names | {"__default__"}
assert sum(pipelines.values()).outputs() == pipeline_names

0 comments on commit 5485c40

Please sign in to comment.