[Draft] Dataset factories - Eager resolving approach #2632

Closed · wants to merge 6 commits
1 change: 1 addition & 0 deletions dependency/requirements.txt
@@ -13,6 +13,7 @@ importlib_resources>=1.3 # The `files()` API was introduced in `importlib_resou
jmespath>=0.9.5, <1.0
more_itertools~=9.0
omegaconf~=2.3
parse~=1.19.0
pip-tools~=6.5
pluggy~=1.0.0
PyYAML>=4.2, <7.0
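For context, the newly added `parse` dependency drives the pattern matching introduced further down in this PR. A rough sketch of how it behaves (not part of the diff; `parse()` inverts `str.format()`-style templates):

```python
from parse import parse

# parse(template, string) returns a Result whose .named dict holds the captured
# placeholder values, or None when the string does not fit the template.
result = parse("{brand}_cars", "tesla_cars")
print(result.named)                    # {'brand': 'tesla'}
print(parse("{brand}_cars", "boats"))  # None
```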
9 changes: 4 additions & 5 deletions kedro/framework/cli/catalog.py
@@ -52,23 +52,22 @@ def list_datasets(metadata: ProjectMetadata, pipeline, env):
mentioned = "Datasets mentioned in pipeline"

session = _create_session(metadata.package_name, env=env)
context = session.load_context()
datasets_meta = context.catalog._data_sets # pylint: disable=protected-access
catalog_ds = set(context.catalog.list())

loaded_catalog = session.load_context().catalog
target_pipelines = pipeline or pipelines.keys()

result = {}
for pipe in target_pipelines:
pl_obj = pipelines.get(pipe)
if pl_obj:
pipeline_ds = pl_obj.data_sets()
loaded_catalog.resolve(pipeline_ds)
else:
existing_pls = ", ".join(sorted(pipelines.keys()))
raise KedroCliError(
f"'{pipe}' pipeline not found! Existing pipelines: {existing_pls}"
)

catalog_ds = set(loaded_catalog.list())
datasets_meta = loaded_catalog._data_sets # pylint: disable=protected-access
unused_ds = catalog_ds - pipeline_ds
default_ds = pipeline_ds - catalog_ds
used_ds = catalog_ds - unused_ds
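For a quick worked example of the set arithmetic above (dataset names are made up for illustration):

```python
catalog_ds = {"boats", "audi_cars", "tesla_cars"}    # explicit + resolved catalog entries
pipeline_ds = {"audi_cars", "tesla_cars", "planes"}  # names used by the pipeline

unused_ds = catalog_ds - pipeline_ds   # {'boats'}   -> defined in the catalog but not used
default_ds = pipeline_ds - catalog_ds  # {'planes'}  -> not in the catalog, listed as defaults
used_ds = catalog_ds - unused_ds       # {'audi_cars', 'tesla_cars'}
```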
5 changes: 5 additions & 0 deletions kedro/framework/session/session.py
@@ -414,6 +414,11 @@ def run( # pylint: disable=too-many-arguments,too-many-locals
load_versions=load_versions,
)

# Get a set of all datasets used by the pipeline
named_datasets = filtered_pipeline.data_sets()
# Resolve the datasets against patterns in the catalog
catalog.resolve(named_datasets)

# Run the runner
hook_manager = self._hook_manager
runner = runner or SequentialRunner()
98 changes: 90 additions & 8 deletions kedro/io/data_catalog.py
@@ -13,6 +13,8 @@
from collections import defaultdict
from typing import Any

from parse import parse

from kedro.io.core import (
AbstractDataSet,
AbstractVersionedDataSet,
@@ -94,6 +96,26 @@ def _sub_nonword_chars(data_set_name: str) -> str:
return re.sub(WORDS_REGEX_PATTERN, "__", data_set_name)


def _specificity(pattern: str) -> int:
"""Helper function to check length of exactly matched characters not inside brackets
Example -
specificity("{namespace}.companies") = 10
specificity("{namespace}.{dataset}") = 1
specificity("france.companies") = 16

Args:
pattern:

Returns:

"""
pattern_variables = parse(pattern, pattern).named
for k in pattern_variables:
pattern_variables[k] = ""
specific_characters = pattern.format(**pattern_variables)
return -len(specific_characters)
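For illustration only (not part of the diff), the sketch below restates the helper and shows how its score combines with the sort key used later in ``DataCatalog.__init__``:

```python
from parse import parse

def specificity(pattern: str) -> int:
    # Blank out every named placeholder and count the literal characters left;
    # the result is negated so ascending sort puts specific patterns first.
    named = parse(pattern, pattern).named
    return -len(pattern.format(**{key: "" for key in named}))

patterns = ["{namespace}.{dataset}", "{namespace}.companies", "france.companies"]
# Same key as DataCatalog.__init__: specificity, then number of placeholders
# (more placeholders first), then alphabetical as the final tie-break.
print(sorted(patterns, key=lambda p: (specificity(p), -p.count("{"), p)))
# ['france.companies', '{namespace}.companies', '{namespace}.{dataset}']
```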


class _FrozenDatasets:
"""Helper class to access underlying loaded datasets"""

@@ -141,6 +163,7 @@ def __init__(
data_sets: dict[str, AbstractDataSet] = None,
feed_dict: dict[str, Any] = None,
layers: dict[str, set[str]] = None,
dataset_patterns: dict[str, Any] = None,
) -> None:
"""``DataCatalog`` stores instances of ``AbstractDataSet``
implementations to provide ``load`` and ``save`` capabilities from
@@ -170,6 +193,14 @@ def __init__(
self._data_sets = dict(data_sets or {})
self.datasets = _FrozenDatasets(self._data_sets)
self.layers = layers
# Keep a record of all patterns in the catalog.
# {dataset pattern name : dataset pattern body}
self.dataset_patterns = dict(dataset_patterns or {})
# Sort all patterns according to parsing rules
self.sorted_patterns = sorted(
self.dataset_patterns.keys(),
key=lambda x: (_specificity(x), -x.count("{"), x),
)

# import the feed dict
if feed_dict:
@@ -257,6 +288,7 @@ class to be loaded is specified with the key ``type`` and their
>>> catalog.save("boats", df)
"""
data_sets = {}
dataset_patterns = {}
catalog = copy.deepcopy(catalog) or {}
credentials = copy.deepcopy(credentials) or {}
save_version = save_version or generate_timestamp()
@@ -271,17 +303,26 @@ class to be loaded is specified with the key ``type`` and their

layers: dict[str, set[str]] = defaultdict(set)
for ds_name, ds_config in catalog.items():
ds_layer = ds_config.pop("layer", None)
if ds_layer is not None:
layers[ds_layer].add(ds_name)
# Assume that any name with } in it is a dataset factory to be matched.
if "}" in ds_name:
# Add each factory to the dataset_patterns dict.
dataset_patterns[ds_name] = ds_config
else:
ds_layer = ds_config.pop("layer", None)
if ds_layer is not None:
layers[ds_layer].add(ds_name)

ds_config = _resolve_credentials(ds_config, credentials)
data_sets[ds_name] = AbstractDataSet.from_config(
ds_name, ds_config, load_versions.get(ds_name), save_version
)
ds_config = _resolve_credentials(ds_config, credentials)
data_sets[ds_name] = AbstractDataSet.from_config(
ds_name, ds_config, load_versions.get(ds_name), save_version
)

dataset_layers = layers or None
return cls(data_sets=data_sets, layers=dataset_layers)
return cls(
data_sets=data_sets,
layers=dataset_layers,
dataset_patterns=dataset_patterns,
)

def _get_dataset(
self, data_set_name: str, version: Version = None, suggest: bool = True
Expand Down Expand Up @@ -594,3 +635,44 @@ def confirm(self, name: str) -> None:
data_set.confirm() # type: ignore
else:
raise DataSetError(f"DataSet '{name}' does not have 'confirm' method")

def resolve(self, named_datasets: set):
"""Resolve the set of datasets used by a pipeline with a pattern if it exists

Args:
named_datasets: A set of datasets used by the pipeline being run
"""

existing_datasets = self.list()
for dataset in named_datasets:
if dataset in existing_datasets:
continue
matched_dataset = self._match_against_patterns(dataset)
if matched_dataset:
self.add(dataset, matched_dataset)

def _match_against_patterns(self, dataset_name: str) -> AbstractDataSet | None:
"""Match a dataset name against the patterns in the catalog
Args:
dataset_name: Name of the dataset to be matched against a specific pattern

Returns:
The dataset instance if the pattern is a match, None otherwise

"""
for pattern in self.sorted_patterns:
result = parse(pattern, dataset_name)
if result:
# Since the patterns are sorted, the first match is the best match
template_copy = copy.deepcopy(self.dataset_patterns[pattern])
for key, value in template_copy.items():
string_value = str(value)
try:
formatted_string = string_value.format_map(result.named)
except KeyError as exc:
raise DataSetError(
f"Unable to resolve '{key}' for the pattern '{pattern}'"
) from exc
template_copy[key] = formatted_string
return AbstractDataSet.from_config(dataset_name, template_copy)
return None
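Putting the pieces together, a minimal usage sketch of the eager resolution flow, assuming a branch with this change installed; the entry bodies mirror the test fixture below rather than anything prescribed elsewhere:

```python
from kedro.io import DataCatalog

config = {
    "{brand}_cars": {
        "type": "pandas.CSVDataSet",
        "filepath": "data/01_raw/{brand}_cars.csv",
    },
    "boats": {"type": "pandas.CSVDataSet", "filepath": "data/01_raw/boats.csv"},
}
catalog = DataCatalog.from_config(config)

# "tesla_cars" has no explicit entry, so resolve() materialises it from the
# "{brand}_cars" pattern; "boats" already exists and is left untouched.
catalog.resolve({"tesla_cars", "boats"})
assert "tesla_cars" in catalog.list()
```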
72 changes: 72 additions & 0 deletions tests/io/test_data_catalog.py
@@ -87,6 +87,23 @@ def sane_config_with_tracking_ds(tmp_path):
}


@pytest.fixture
def config_with_dataset_factories():
return {
"catalog": {
"{brand}_cars": {
"type": "pandas.CSVDataSet",
"filepath": "data/01_raw/{brand}_cars.csv",
},
"audi_cars": {
"type": "pandas.ParquetDataSet",
"filepath": "data/01_raw/audi_cars.xls",
},
"boats": {"type": "pandas.CSVDataSet", "filepath": "data/01_raw/boats.csv"},
},
}


@pytest.fixture
def data_set(filepath):
return CSVDataSet(filepath=filepath, save_args={"index": False})
@@ -683,3 +700,58 @@ def test_no_versions_with_cloud_protocol(self):
)
with pytest.raises(DataSetError, match=pattern):
versioned_dataset.load()


class TestDataCatalogDatasetFactories:
def test_patterns_not_in_catalog_datasets(self, config_with_dataset_factories):
"""Check that the pattern is not in the catalog datasets"""
catalog = DataCatalog.from_config(**config_with_dataset_factories)
assert "audi_cars" in catalog._data_sets
assert "{brand}_cars" not in catalog._data_sets
assert "audi_cars" not in catalog.dataset_patterns
assert "{brand}_cars" in catalog.dataset_patterns

def test_pattern_matching_only_one_match(self, config_with_dataset_factories):
"""Check that dataset names are added to the catalog when one pattern exists"""
catalog = DataCatalog.from_config(**config_with_dataset_factories)
named_datasets = {"tesla_cars", "audi_cars", "ford_cars", "boats"}
# Before resolution
assert "tesla_cars" not in catalog._data_sets
assert "audi_cars" in catalog._data_sets
assert "ford_cars" not in catalog._data_sets
assert "boats" in catalog._data_sets
catalog.resolve(named_datasets)
# After resolution
assert "tesla_cars" in catalog._data_sets
assert "audi_cars" in catalog._data_sets
assert "ford_cars" in catalog._data_sets
assert "boats" in catalog._data_sets

def test_explicit_entry_not_overwritten(self, config_with_dataset_factories):
"""Check that the existing catalog entry is not overwritten by config in pattern"""
catalog = DataCatalog.from_config(**config_with_dataset_factories)
named_datasets = {"tesla_cars", "audi_cars", "ford_cars", "boats"}
audi_cars_before = catalog._get_dataset("audi_cars")
catalog.resolve(named_datasets)
audi_cars_after = catalog._get_dataset("audi_cars")
assert audi_cars_before == audi_cars_after

def test_dataset_not_in_catalog_when_no_pattern_match(self):
"""Check that the dataset is not added to the catalog when there is no pattern"""
assert True

def test_dataset_pattern_ordering(self):
"""Check that the patterns are ordered correctly according to the parsing rules"""
assert True

def test_pattern_matching_multiple_patterns(self):
"""Check that the patterns are matched correctly when multiple patterns exist"""
assert True

def test_config_parsed_from_pattern(self):
"""Check that the body of the dataset entry is correctly parsed"""
assert True

def test_unmatched_key_error_when_parsing_config(self):
"""Check error raised when key mentioned in the config is not in pattern name"""
assert True
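
One possible shape for the ordering stub above; the fixture values and the test body are hypothetical, not part of this draft:

```python
def test_dataset_pattern_ordering_sketch(self):
    """More specific patterns should come before catch-alls in sorted_patterns."""
    config = {
        "catalog": {
            "{namespace}.{dataset}": {
                "type": "pandas.CSVDataSet",
                "filepath": "data/{namespace}/{dataset}.csv",
            },
            "{namespace}.companies": {
                "type": "pandas.CSVDataSet",
                "filepath": "data/{namespace}/companies.csv",
            },
            "preprocessed_{dataset}": {
                "type": "pandas.CSVDataSet",
                "filepath": "data/preprocessed_{dataset}.csv",
            },
        }
    }
    catalog = DataCatalog.from_config(**config)
    # Specificity scores: -13, -10, -1 respectively, so the catch-all comes last.
    assert catalog.sorted_patterns == [
        "preprocessed_{dataset}",
        "{namespace}.companies",
        "{namespace}.{dataset}",
    ]
```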