diff --git a/dependency/requirements.txt b/dependency/requirements.txt
index ad4e428f49..a6a0b36b49 100644
--- a/dependency/requirements.txt
+++ b/dependency/requirements.txt
@@ -13,6 +13,7 @@ importlib_resources>=1.3 # The `files()` API was introduced in `importlib_resou
 jmespath>=0.9.5, <1.0
 more_itertools~=9.0
 omegaconf~=2.3
+parse~=1.19.0
 pip-tools~=6.5
 pluggy~=1.0.0
 PyYAML>=4.2, <7.0
diff --git a/kedro/framework/cli/catalog.py b/kedro/framework/cli/catalog.py
index 5fd64fdd43..458c2d5d88 100644
--- a/kedro/framework/cli/catalog.py
+++ b/kedro/framework/cli/catalog.py
@@ -52,10 +52,7 @@ def list_datasets(metadata: ProjectMetadata, pipeline, env):
     mentioned = "Datasets mentioned in pipeline"
 
     session = _create_session(metadata.package_name, env=env)
-    context = session.load_context()
-    datasets_meta = context.catalog._data_sets  # pylint: disable=protected-access
-    catalog_ds = set(context.catalog.list())
-
+    loaded_catalog = session.load_context().catalog
     target_pipelines = pipeline or pipelines.keys()
 
     result = {}
@@ -63,12 +60,14 @@ def list_datasets(metadata: ProjectMetadata, pipeline, env):
         pl_obj = pipelines.get(pipe)
         if pl_obj:
             pipeline_ds = pl_obj.data_sets()
+            loaded_catalog.resolve(pipeline_ds)
         else:
             existing_pls = ", ".join(sorted(pipelines.keys()))
             raise KedroCliError(
                 f"'{pipe}' pipeline not found! Existing pipelines: {existing_pls}"
             )
-
+        catalog_ds = set(loaded_catalog.list())
+        datasets_meta = loaded_catalog._data_sets  # pylint: disable=protected-access
         unused_ds = catalog_ds - pipeline_ds
         default_ds = pipeline_ds - catalog_ds
         used_ds = catalog_ds - unused_ds
diff --git a/kedro/framework/session/session.py b/kedro/framework/session/session.py
index 9085bc9abf..4d44fa2554 100644
--- a/kedro/framework/session/session.py
+++ b/kedro/framework/session/session.py
@@ -414,6 +414,11 @@ def run(  # pylint: disable=too-many-arguments,too-many-locals
             load_versions=load_versions,
         )
 
+        # Get a set of all datasets used by the pipeline
+        named_datasets = filtered_pipeline.data_sets()
+        # Resolve the datasets against patterns in the catalog
+        catalog.resolve(named_datasets)
+
        # Run the runner
         hook_manager = self._hook_manager
         runner = runner or SequentialRunner()
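For orientation, the run-time flow introduced here is: collect the dataset names the filtered pipeline refers to, then ask the catalog to materialise any name that only exists as a pattern. Below is a minimal sketch of that flow, assuming the `DataCatalog.resolve` API added in the data_catalog.py changes further down; the entries, dataset names and file paths are invented for illustration, and a working `pandas.CSVDataSet` is assumed to be installed.

    from kedro.io import DataCatalog

    # Hypothetical catalog config: one explicit entry plus one dataset factory pattern.
    config = {
        "boats": {"type": "pandas.CSVDataSet", "filepath": "data/01_raw/boats.csv"},
        "{brand}_cars": {
            "type": "pandas.CSVDataSet",
            "filepath": "data/01_raw/{brand}_cars.csv",
        },
    }
    catalog = DataCatalog.from_config(config)

    # What session.run() now does before invoking the runner:
    named_datasets = {"boats", "tesla_cars"}  # stand-in for filtered_pipeline.data_sets()
    catalog.resolve(named_datasets)  # adds a "tesla_cars" entry built from the pattern
    assert "tesla_cars" in catalog.list()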
diff --git a/kedro/io/data_catalog.py b/kedro/io/data_catalog.py
index 57c31682c0..2960a5bb55 100644
--- a/kedro/io/data_catalog.py
+++ b/kedro/io/data_catalog.py
@@ -13,6 +13,8 @@
 from collections import defaultdict
 from typing import Any
 
+from parse import parse
+
 from kedro.io.core import (
     AbstractDataSet,
     AbstractVersionedDataSet,
@@ -94,6 +96,26 @@ def _sub_nonword_chars(data_set_name: str) -> str:
     return re.sub(WORDS_REGEX_PATTERN, "__", data_set_name)
 
 
+def _specificity(pattern: str) -> int:
+    """Count the characters of a pattern that sit outside curly brackets, negated.
+    Example -
+        _specificity("{namespace}.companies") == -10  # ".companies"
+        _specificity("{namespace}.{dataset}") == -1   # "."
+        _specificity("france.companies") == -16
+
+    Args:
+        pattern: The dataset factory pattern to score.
+
+    Returns:
+        The negative number of literal characters, so more specific patterns sort first.
+    """
+    pattern_variables = parse(pattern, pattern).named
+    for k in pattern_variables:
+        pattern_variables[k] = ""
+    specific_characters = pattern.format(**pattern_variables)
+    return -len(specific_characters)
+
+
 class _FrozenDatasets:
     """Helper class to access underlying loaded datasets"""
 
@@ -141,6 +163,7 @@ def __init__(
         data_sets: dict[str, AbstractDataSet] = None,
         feed_dict: dict[str, Any] = None,
         layers: dict[str, set[str]] = None,
+        dataset_patterns: dict[str, Any] = None,
     ) -> None:
         """``DataCatalog`` stores instances of ``AbstractDataSet``
         implementations to provide ``load`` and ``save`` capabilities from
@@ -170,6 +193,14 @@ def __init__(
         self._data_sets = dict(data_sets or {})
         self.datasets = _FrozenDatasets(self._data_sets)
         self.layers = layers
+        # Keep a record of all patterns in the catalog:
+        # {dataset pattern name: dataset pattern body}
+        self.dataset_patterns = dict(dataset_patterns or {})
+        # Sort all patterns according to the parsing rules (most specific first)
+        self.sorted_patterns = sorted(
+            self.dataset_patterns.keys(),
+            key=lambda x: (_specificity(x), -x.count("{"), x),
+        )
 
         # import the feed dict
         if feed_dict:
@@ -257,6 +288,7 @@ class to be loaded is specified with the key ``type`` and their
             >>> catalog.save("boats", df)
         """
         data_sets = {}
+        dataset_patterns = {}
         catalog = copy.deepcopy(catalog) or {}
         credentials = copy.deepcopy(credentials) or {}
         save_version = save_version or generate_timestamp()
@@ -271,17 +303,26 @@ class to be loaded is specified with the key ``type`` and their
         layers: dict[str, set[str]] = defaultdict(set)
         for ds_name, ds_config in catalog.items():
-            ds_layer = ds_config.pop("layer", None)
-            if ds_layer is not None:
-                layers[ds_layer].add(ds_name)
+            # Assume that any name with "}" in it is a dataset factory to be matched.
+            if "}" in ds_name:
+                # Add each factory to the dataset_patterns dict.
+                dataset_patterns[ds_name] = ds_config
+            else:
+                ds_layer = ds_config.pop("layer", None)
+                if ds_layer is not None:
+                    layers[ds_layer].add(ds_name)
 
-            ds_config = _resolve_credentials(ds_config, credentials)
-            data_sets[ds_name] = AbstractDataSet.from_config(
-                ds_name, ds_config, load_versions.get(ds_name), save_version
-            )
+                ds_config = _resolve_credentials(ds_config, credentials)
+                data_sets[ds_name] = AbstractDataSet.from_config(
+                    ds_name, ds_config, load_versions.get(ds_name), save_version
+                )
 
         dataset_layers = layers or None
-        return cls(data_sets=data_sets, layers=dataset_layers)
+        return cls(
+            data_sets=data_sets,
+            layers=dataset_layers,
+            dataset_patterns=dataset_patterns,
+        )
 
     def _get_dataset(
         self, data_set_name: str, version: Version = None, suggest: bool = True
@@ -594,3 +635,44 @@ def confirm(self, name: str) -> None:
             data_set.confirm()  # type: ignore
         else:
             raise DataSetError(f"DataSet '{name}' does not have 'confirm' method")
+
+    def resolve(self, named_datasets: set) -> None:
+        """Resolve dataset names used by a pipeline against the catalog's patterns,
+        registering a dataset for every name that matches one.
+
+        Args:
+            named_datasets: A set of dataset names used by the pipeline being run.
+        """
+        existing_datasets = self.list()
+        for dataset in named_datasets:
+            if dataset in existing_datasets:
+                continue
+            matched_dataset = self._match_against_patterns(dataset)
+            if matched_dataset:
+                self.add(dataset, matched_dataset)
+
+    def _match_against_patterns(self, dataset_name: str) -> AbstractDataSet | None:
+        """Match a dataset name against the sorted patterns in the catalog.
+
+        Args:
+            dataset_name: Name of the dataset to be matched against the patterns.
+
+        Returns:
+            The dataset instance built from the best-matching pattern, None otherwise.
+        """
+        for pattern in self.sorted_patterns:
+            result = parse(pattern, dataset_name)
+            if result:
+                # Since the patterns are sorted, the first match is the best match.
+                template_copy = copy.deepcopy(self.dataset_patterns[pattern])
+                for key, value in template_copy.items():
+                    string_value = str(value)
+                    try:
+                        formatted_string = string_value.format_map(result.named)
+                    except KeyError as exc:
+                        raise DataSetError(
+                            f"Unable to resolve '{key}' for the pattern '{pattern}'"
+                        ) from exc
+                    template_copy[key] = formatted_string
+                return AbstractDataSet.from_config(dataset_name, template_copy)
+        return None
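The sort key in `__init__` is the easiest part to misread: patterns are ordered by the negated count of literal characters, then by how few placeholders they contain, then alphabetically, so the most specific pattern is tried first. A self-contained sketch of that ordering, mirroring `_specificity` on illustrative patterns that are not part of this diff:

    from parse import parse

    def specificity(pattern: str) -> int:
        # Negative count of the characters outside "{...}" placeholders,
        # mirroring the _specificity helper above.
        named = parse(pattern, pattern).named
        return -len(pattern.format(**{key: "" for key in named}))

    patterns = ["{namespace}.{dataset}", "{namespace}.companies", "france.companies"]
    ordered = sorted(patterns, key=lambda p: (specificity(p), -p.count("{"), p))

    # Most literal characters first, then fewest placeholders, then alphabetical:
    assert ordered == ["france.companies", "{namespace}.companies", "{namespace}.{dataset}"]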
diff --git a/tests/io/test_data_catalog.py b/tests/io/test_data_catalog.py
index 89f50d1ea5..af3d378b5e 100644
--- a/tests/io/test_data_catalog.py
+++ b/tests/io/test_data_catalog.py
@@ -87,6 +87,23 @@ def sane_config_with_tracking_ds(tmp_path):
     }
 
 
+@pytest.fixture
+def config_with_dataset_factories():
+    return {
+        "catalog": {
+            "{brand}_cars": {
+                "type": "pandas.CSVDataSet",
+                "filepath": "data/01_raw/{brand}_cars.csv",
+            },
+            "audi_cars": {
+                "type": "pandas.ParquetDataSet",
+                "filepath": "data/01_raw/audi_cars.parquet",
+            },
+            "boats": {"type": "pandas.CSVDataSet", "filepath": "data/01_raw/boats.csv"},
+        },
+    }
+
+
 @pytest.fixture
 def data_set(filepath):
     return CSVDataSet(filepath=filepath, save_args={"index": False})
@@ -683,3 +700,58 @@ def test_no_versions_with_cloud_protocol(self):
         )
         with pytest.raises(DataSetError, match=pattern):
             versioned_dataset.load()
+
+
+class TestDataCatalogDatasetFactories:
+    def test_patterns_not_in_catalog_datasets(self, config_with_dataset_factories):
+        """Check that patterns are kept separate from the materialised catalog datasets"""
+        catalog = DataCatalog.from_config(**config_with_dataset_factories)
+        assert "audi_cars" in catalog._data_sets
+        assert "{brand}_cars" not in catalog._data_sets
+        assert "audi_cars" not in catalog.dataset_patterns
+        assert "{brand}_cars" in catalog.dataset_patterns
+
+    def test_pattern_matching_only_one_match(self, config_with_dataset_factories):
+        """Check that dataset names are added to the catalog when one pattern exists"""
+        catalog = DataCatalog.from_config(**config_with_dataset_factories)
+        named_datasets = {"tesla_cars", "audi_cars", "ford_cars", "boats"}
+        # Before resolution
+        assert "tesla_cars" not in catalog._data_sets
+        assert "audi_cars" in catalog._data_sets
+        assert "ford_cars" not in catalog._data_sets
+        assert "boats" in catalog._data_sets
+        catalog.resolve(named_datasets)
+        # After resolution
+        assert "tesla_cars" in catalog._data_sets
+        assert "audi_cars" in catalog._data_sets
+        assert "ford_cars" in catalog._data_sets
+        assert "boats" in catalog._data_sets
+
+    def test_explicit_entry_not_overwritten(self, config_with_dataset_factories):
+        """Check that an existing catalog entry is not overwritten by a matching pattern"""
+        catalog = DataCatalog.from_config(**config_with_dataset_factories)
+        named_datasets = {"tesla_cars", "audi_cars", "ford_cars", "boats"}
+        audi_cars_before = catalog._get_dataset("audi_cars")
+        catalog.resolve(named_datasets)
+        audi_cars_after = catalog._get_dataset("audi_cars")
+        assert audi_cars_before == audi_cars_after
+
+    def test_dataset_not_in_catalog_when_no_pattern_match(self):
+        """Check that a dataset is not added to the catalog when no pattern matches it"""
+        assert True
+
+    def test_dataset_pattern_ordering(self):
+        """Check that the patterns are ordered correctly according to the parsing rules"""
+        assert True
+
+    def test_pattern_matching_multiple_patterns(self):
+        """Check that the patterns are matched correctly when multiple patterns exist"""
+        assert True
+
+    def test_config_parsed_from_pattern(self):
+        """Check that the body of the dataset entry is correctly parsed"""
+        assert True
+
+    def test_unmatched_key_error_when_parsing_config(self):
+        """Check error raised when a key mentioned in the config is not in the pattern name"""
+        assert True
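The last five tests are placeholders. As one illustration of the behaviour they are meant to pin down, here is a sketch of how the ordering and multiple-pattern cases could be exercised inside TestDataCatalogDatasetFactories; the extra "{name}" catch-all pattern and the file paths are invented for this example, and it assumes the test module's existing imports as well as the private `_filepath` attribute of `pandas.CSVDataSet`.

    def test_pattern_ordering_and_multiple_matches_sketch(self):
        """More specific patterns should be tried before generic catch-alls."""
        config = {
            "catalog": {
                "{brand}_cars": {
                    "type": "pandas.CSVDataSet",
                    "filepath": "data/01_raw/{brand}_cars.csv",
                },
                "{name}": {
                    "type": "pandas.CSVDataSet",
                    "filepath": "data/01_raw/{name}.csv",
                },
            },
        }
        catalog = DataCatalog.from_config(**config)
        # "{brand}_cars" has more literal characters, so it sorts ahead of "{name}".
        assert catalog.sorted_patterns == ["{brand}_cars", "{name}"]
        catalog.resolve({"tesla_cars", "planes"})
        # "tesla_cars" matches both patterns; the more specific one supplies its config.
        assert str(catalog._get_dataset("tesla_cars")._filepath).endswith("tesla_cars.csv")
        # "planes" only matches the catch-all "{name}" pattern.
        assert str(catalog._get_dataset("planes")._filepath).endswith("planes.csv")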