From 1c6a88d90e91f6cdd76e4cd0bc94113b8eaa7750 Mon Sep 17 00:00:00 2001
From: Merel Theisen
Date: Thu, 5 Oct 2023 16:36:12 +0100
Subject: [PATCH 1/9] Add partitioned and incremental datasets to repo

Signed-off-by: Merel Theisen
---
 kedro-datasets/RELEASE.md                     |   2 +
 kedro-datasets/docs/source/kedro_datasets.rst |   2 +
 .../kedro_datasets/partitions/__init__.py     |  11 +
 .../partitions/incremental_dataset.py         | 237 ++++
 .../partitions/partitioned_dataset.py         | 329 +++++++++++
 kedro-datasets/tests/partitions/__init__.py   |   0
 .../partitions/test_incremental_dataset.py    | 508 ++++++++++++++++
 .../partitions/test_partitioned_dataset.py    | 540 ++++++++++++++++++
 8 files changed, 1629 insertions(+)
 create mode 100644 kedro-datasets/kedro_datasets/partitions/__init__.py
 create mode 100644 kedro-datasets/kedro_datasets/partitions/incremental_dataset.py
 create mode 100644 kedro-datasets/kedro_datasets/partitions/partitioned_dataset.py
 create mode 100644 kedro-datasets/tests/partitions/__init__.py
 create mode 100644 kedro-datasets/tests/partitions/test_incremental_dataset.py
 create mode 100644 kedro-datasets/tests/partitions/test_partitioned_dataset.py

diff --git a/kedro-datasets/RELEASE.md b/kedro-datasets/RELEASE.md
index 9c6661fda..168e7d72f 100755
--- a/kedro-datasets/RELEASE.md
+++ b/kedro-datasets/RELEASE.md
@@ -1,5 +1,7 @@
 # Upcoming Release
 ## Major features and improvements
+* Moved `PartitionedDataSet` and `IncrementalDataSet` from the core Kedro repo to `kedro-datasets`.
+
 ## Bug fixes and other changes
 ## Upcoming deprecations for Kedro-Datasets 2.0.0
 * Renamed dataset and error classes, in accordance with the [Kedro lexicon](https://github.com/kedro-org/kedro/wiki/Kedro-documentation-style-guide#kedro-lexicon). Dataset classes ending with "DataSet" are deprecated and will be removed in 2.0.0.
diff --git a/kedro-datasets/docs/source/kedro_datasets.rst b/kedro-datasets/docs/source/kedro_datasets.rst
index d8db36ee0..67f87e0e3 100644
--- a/kedro-datasets/docs/source/kedro_datasets.rst
+++ b/kedro-datasets/docs/source/kedro_datasets.rst
@@ -59,6 +59,8 @@ kedro_datasets
    kedro_datasets.pandas.SQLTableDataset
    kedro_datasets.pandas.XMLDataSet
    kedro_datasets.pandas.XMLDataset
+   kedro_datasets.partitions.IncrementalDataset
+   kedro_datasets.partitions.PartitionedDataset
    kedro_datasets.pickle.PickleDataSet
    kedro_datasets.pickle.PickleDataset
    kedro_datasets.pillow.ImageDataSet
diff --git a/kedro-datasets/kedro_datasets/partitions/__init__.py b/kedro-datasets/kedro_datasets/partitions/__init__.py
new file mode 100644
index 000000000..2f464a907
--- /dev/null
+++ b/kedro-datasets/kedro_datasets/partitions/__init__.py
@@ -0,0 +1,11 @@
+"""``AbstractDataset`` implementation to load/save data in partitions
+from/to any underlying Dataset format.
+"""
+
+__all__ = ["PartitionedDataset", "IncrementalDataset"]
+
+from contextlib import suppress
+
+with suppress(ImportError):
+    from .incremental_dataset import IncrementalDataset
+    from .partitioned_dataset import PartitionedDataset
diff --git a/kedro-datasets/kedro_datasets/partitions/incremental_dataset.py b/kedro-datasets/kedro_datasets/partitions/incremental_dataset.py
new file mode 100644
index 000000000..9623a5893
--- /dev/null
+++ b/kedro-datasets/kedro_datasets/partitions/incremental_dataset.py
@@ -0,0 +1,237 @@
+"""``IncrementalDataset`` inherits from ``PartitionedDataset``, which loads
+and saves partitioned file-like data using the underlying dataset
+definition.
``IncrementalDataset`` also stores the information about the last +processed partition in so-called `checkpoint` that is persisted to the location +of the data partitions by default, so that subsequent pipeline run loads only +new partitions past the checkpoint.It also uses `fsspec` for filesystem level operations. +""" +from __future__ import annotations + +import operator +from copy import deepcopy +from typing import Any, Callable + +from cachetools import cachedmethod +from kedro.io.core import ( + VERSION_KEY, + VERSIONED_FLAG_KEY, + AbstractDataset, + DatasetError, + parse_dataset_definition, +) +from kedro.io.data_catalog import CREDENTIALS_KEY +from kedro.utils import load_obj + +from .partitioned_dataset import KEY_PROPAGATION_WARNING, PartitionedDataset + + +class IncrementalDataset(PartitionedDataset): + """``IncrementalDataset`` inherits from ``PartitionedDataset``, which loads + and saves partitioned file-like data using the underlying dataset + definition. For filesystem level operations it uses `fsspec`: + https://github.com/intake/filesystem_spec. ``IncrementalDataset`` also stores + the information about the last processed partition in so-called `checkpoint` + that is persisted to the location of the data partitions by default, so that + subsequent pipeline run loads only new partitions past the checkpoint. + + Example: + :: + + >>> from kedro_datasets.partitions import IncrementalDataset + >>> + >>> # these credentials will be passed to: + >>> # a) 'fsspec.filesystem()' call, + >>> # b) the dataset initializer, + >>> # c) the checkpoint initializer + >>> credentials = {"key1": "secret1", "key2": "secret2"} + >>> + >>> data_set = IncrementalDataset( + >>> path="s3://bucket-name/path/to/folder", + >>> dataset="pandas.CSVDataset", + >>> credentials=credentials + >>> ) + >>> loaded = data_set.load() # loads all available partitions + >>> # assert isinstance(loaded, dict) + >>> + >>> data_set.confirm() # update checkpoint value to the last processed partition ID + >>> reloaded = data_set.load() # still loads all available partitions + >>> + >>> data_set.release() # clears load cache + >>> # returns an empty dictionary as no new partitions were added + >>> data_set.load() + """ + + DEFAULT_CHECKPOINT_TYPE = "kedro_datasets.text.TextDataSet" + DEFAULT_CHECKPOINT_FILENAME = "CHECKPOINT" + + def __init__( # noqa: PLR0913 + self, + path: str, + dataset: str | type[AbstractDataset] | dict[str, Any], + checkpoint: str | dict[str, Any] | None = None, + filepath_arg: str = "filepath", + filename_suffix: str = "", + credentials: dict[str, Any] = None, + load_args: dict[str, Any] = None, + fs_args: dict[str, Any] = None, + metadata: dict[str, Any] = None, + ) -> None: + """Creates a new instance of ``IncrementalDataset``. + + Args: + path: Path to the folder containing partitioned data. + If path starts with the protocol (e.g., ``s3://``) then the + corresponding ``fsspec`` concrete filesystem implementation will + be used. If protocol is not specified, + ``fsspec.implementations.local.LocalFileSystem`` will be used. + **Note:** Some concrete implementations are bundled with ``fsspec``, + while others (like ``s3`` or ``gcs``) must be installed separately + prior to usage of the ``PartitionedDataset``. + dataset: Underlying dataset definition. This is used to instantiate + the dataset for each file located inside the ``path``. 
+ Accepted formats are: + a) object of a class that inherits from ``AbstractDataset`` + b) a string representing a fully qualified class name to such class + c) a dictionary with ``type`` key pointing to a string from b), + other keys are passed to the Dataset initializer. + Credentials for the dataset can be explicitly specified in + this configuration. + checkpoint: Optional checkpoint configuration. Accepts a dictionary + with the corresponding dataset definition including ``filepath`` + (unlike ``dataset`` argument). Checkpoint configuration is + described here: + https://kedro.readthedocs.io/en/stable/data/kedro_io.html#checkpoint-configuration + Credentials for the checkpoint can be explicitly specified + in this configuration. + filepath_arg: Underlying dataset initializer argument that will + contain a path to each corresponding partition file. + If unspecified, defaults to "filepath". + filename_suffix: If specified, only partitions that end with this + string will be processed. + credentials: Protocol-specific options that will be passed to + ``fsspec.filesystem`` + https://filesystem-spec.readthedocs.io/en/latest/api.html#fsspec.filesystem, + the dataset dataset initializer and the checkpoint. If + the dataset or the checkpoint configuration contains explicit + credentials spec, then such spec will take precedence. + All possible credentials management scenarios are documented here: + https://kedro.readthedocs.io/en/stable/data/kedro_io.html#partitioned-dataset-credentials + load_args: Keyword arguments to be passed into ``find()`` method of + the filesystem implementation. + fs_args: Extra arguments to pass into underlying filesystem class constructor + (e.g. `{"project": "my-project"}` for ``GCSFileSystem``). + metadata: Any arbitrary metadata. + This is ignored by Kedro, but may be consumed by users or external plugins. + + Raises: + DatasetError: If versioning is enabled for the underlying dataset. + """ + + super().__init__( + path=path, + dataset=dataset, + filepath_arg=filepath_arg, + filename_suffix=filename_suffix, + credentials=credentials, + load_args=load_args, + fs_args=fs_args, + ) + + self._checkpoint_config = self._parse_checkpoint_config(checkpoint) + self._force_checkpoint = self._checkpoint_config.pop("force_checkpoint", None) + self.metadata = metadata + + comparison_func = self._checkpoint_config.pop("comparison_func", operator.gt) + if isinstance(comparison_func, str): + comparison_func = load_obj(comparison_func) + self._comparison_func = comparison_func + + def _parse_checkpoint_config( + self, checkpoint_config: str | dict[str, Any] | None + ) -> dict[str, Any]: + checkpoint_config = deepcopy(checkpoint_config) + if isinstance(checkpoint_config, str): + checkpoint_config = {"force_checkpoint": checkpoint_config} + checkpoint_config = checkpoint_config or {} + + for key in {VERSION_KEY, VERSIONED_FLAG_KEY} & checkpoint_config.keys(): + raise DatasetError( + f"'{self.__class__.__name__}' does not support versioning of the " + f"checkpoint. Please remove '{key}' key from the checkpoint definition." 
+ ) + + default_checkpoint_path = self._sep.join( + [self._normalized_path.rstrip(self._sep), self.DEFAULT_CHECKPOINT_FILENAME] + ) + default_config = { + "type": self.DEFAULT_CHECKPOINT_TYPE, + self._filepath_arg: default_checkpoint_path, + } + if self._credentials: + default_config[CREDENTIALS_KEY] = deepcopy(self._credentials) + + if CREDENTIALS_KEY in default_config.keys() & checkpoint_config.keys(): + self._logger.warning( + KEY_PROPAGATION_WARNING, + {"keys": CREDENTIALS_KEY, "target": "checkpoint"}, + ) + + return {**default_config, **checkpoint_config} + + @cachedmethod(cache=operator.attrgetter("_partition_cache")) + def _list_partitions(self) -> list[str]: + checkpoint = self._read_checkpoint() + checkpoint_path = self._filesystem._strip_protocol( + self._checkpoint_config[self._filepath_arg] + ) + + def _is_valid_partition(partition) -> bool: + if not partition.endswith(self._filename_suffix): + return False + if partition == checkpoint_path: + return False + if checkpoint is None: + # nothing was processed yet + return True + partition_id = self._path_to_partition(partition) + return self._comparison_func(partition_id, checkpoint) + + return sorted( + part + for part in self._filesystem.find(self._normalized_path, **self._load_args) + if _is_valid_partition(part) + ) + + @property + def _checkpoint(self) -> AbstractDataset: + type_, kwargs = parse_dataset_definition(self._checkpoint_config) + return type_(**kwargs) # type: ignore + + def _read_checkpoint(self) -> str | None: + if self._force_checkpoint is not None: + return self._force_checkpoint + try: + return self._checkpoint.load() + except DatasetError: + return None + + def _load(self) -> dict[str, Callable[[], Any]]: + partitions: dict[str, Any] = {} + + for partition in self._list_partitions(): + partition_id = self._path_to_partition(partition) + kwargs = deepcopy(self._dataset_config) + # join the protocol back since PySpark may rely on it + kwargs[self._filepath_arg] = self._join_protocol(partition) + partitions[partition_id] = self._dataset_type( # type: ignore + **kwargs + ).load() + + return partitions + + def confirm(self) -> None: + """Confirm the dataset by updating the checkpoint value to the latest + processed partition ID""" + partition_ids = [self._path_to_partition(p) for p in self._list_partitions()] + if partition_ids: + self._checkpoint.save(partition_ids[-1]) # checkpoint to last partition diff --git a/kedro-datasets/kedro_datasets/partitions/partitioned_dataset.py b/kedro-datasets/kedro_datasets/partitions/partitioned_dataset.py new file mode 100644 index 000000000..74242b113 --- /dev/null +++ b/kedro-datasets/kedro_datasets/partitions/partitioned_dataset.py @@ -0,0 +1,329 @@ +"""``PartitionedDataset`` loads and saves partitioned file-like data using the +underlying dataset definition. It also uses `fsspec` for filesystem level operations. +""" +from __future__ import annotations + +import operator +from copy import deepcopy +from typing import Any, Callable, Dict +from urllib.parse import urlparse +from warnings import warn + +import fsspec +from cachetools import Cache, cachedmethod +from kedro.io.core import ( + VERSION_KEY, + VERSIONED_FLAG_KEY, + AbstractDataset, + DatasetError, + parse_dataset_definition, +) +from kedro.io.data_catalog import CREDENTIALS_KEY + +KEY_PROPAGATION_WARNING = ( + "Top-level %(keys)s will not propagate into the %(target)s since " + "%(keys)s were explicitly defined in the %(target)s config." 
+) + +S3_PROTOCOLS = ("s3", "s3a", "s3n") + + +class PartitionedDataset(AbstractDataset[Dict[str, Any], Dict[str, Callable[[], Any]]]): + """``PartitionedDataset`` loads and saves partitioned file-like data using the + underlying dataset definition. For filesystem level operations it uses `fsspec`: + https://github.com/intake/filesystem_spec. + + It also supports advanced features like + `lazy saving `_. + + Example usage for the + `YAML API `_: + + .. code-block:: yaml + + station_data: + type: PartitionedDataset + path: data/03_primary/station_data + dataset: + type: pandas.CSVDataset + load_args: + sep: '\\t' + save_args: + sep: '\\t' + index: true + filename_suffix: '.dat' + + Example usage for the + `Python API `_: + :: + + >>> import pandas as pd + >>> from kedro_datasets.partitions import PartitionedDataset + >>> + >>> # Create a fake pandas dataframe with 10 rows of data + >>> df = pd.DataFrame([{"DAY_OF_MONTH": str(i), "VALUE": i} for i in range(1, 11)]) + >>> + >>> # Convert it to a dict of pd.DataFrame with DAY_OF_MONTH as the dict key + >>> dict_df = { + day_of_month: df[df["DAY_OF_MONTH"] == day_of_month] + for day_of_month in df["DAY_OF_MONTH"] + } + >>> + >>> # Save it as small paritions with DAY_OF_MONTH as the partition key + >>> data_set = PartitionedDataset( + path="df_with_partition", + dataset="pandas.CSVDataset", + filename_suffix=".csv" + ) + >>> # This will create a folder `df_with_partition` and save multiple files + >>> # with the dict key + filename_suffix as filename, i.e. 1.csv, 2.csv etc. + >>> data_set.save(dict_df) + >>> + >>> # This will create lazy load functions instead of loading data into memory immediately. + >>> loaded = data_set.load() + >>> + >>> # Load all the partitions + >>> for partition_id, partition_load_func in loaded.items(): + # The actual function that loads the data + partition_data = partition_load_func() + >>> + >>> # Add the processing logic for individual partition HERE + >>> print(partition_data) + + You can also load multiple partitions from a remote storage and combine them + like this: + :: + + >>> import pandas as pd + >>> from kedro_datasets.partitions import PartitionedDataset + >>> + >>> # these credentials will be passed to both 'fsspec.filesystem()' call + >>> # and the dataset initializer + >>> credentials = {"key1": "secret1", "key2": "secret2"} + >>> + >>> data_set = PartitionedDataset( + path="s3://bucket-name/path/to/folder", + dataset="pandas.CSVDataset", + credentials=credentials + ) + >>> loaded = data_set.load() + >>> # assert isinstance(loaded, dict) + >>> + >>> combine_all = pd.DataFrame() + >>> + >>> for partition_id, partition_load_func in loaded.items(): + partition_data = partition_load_func() + combine_all = pd.concat( + [combine_all, partition_data], ignore_index=True, sort=True + ) + >>> + >>> new_data = pd.DataFrame({"new": [1, 2]}) + >>> # creates "s3://bucket-name/path/to/folder/new/partition.csv" + >>> data_set.save({"new/partition.csv": new_data}) + + """ + + def __init__( # noqa: PLR0913 + self, + path: str, + dataset: str | type[AbstractDataset] | dict[str, Any], + filepath_arg: str = "filepath", + filename_suffix: str = "", + credentials: dict[str, Any] = None, + load_args: dict[str, Any] = None, + fs_args: dict[str, Any] = None, + overwrite: bool = False, + metadata: dict[str, Any] = None, + ) -> None: + """Creates a new instance of ``PartitionedDataset``. + + Args: + path: Path to the folder containing partitioned data. 
+ If path starts with the protocol (e.g., ``s3://``) then the + corresponding ``fsspec`` concrete filesystem implementation will + be used. If protocol is not specified, + ``fsspec.implementations.local.LocalFileSystem`` will be used. + **Note:** Some concrete implementations are bundled with ``fsspec``, + while others (like ``s3`` or ``gcs``) must be installed separately + prior to usage of the ``PartitionedDataset``. + dataset: Underlying dataset definition. This is used to instantiate + the dataset for each file located inside the ``path``. + Accepted formats are: + a) object of a class that inherits from ``AbstractDataset`` + b) a string representing a fully qualified class name to such class + c) a dictionary with ``type`` key pointing to a string from b), + other keys are passed to the Dataset initializer. + Credentials for the dataset can be explicitly specified in + this configuration. + filepath_arg: Underlying dataset initializer argument that will + contain a path to each corresponding partition file. + If unspecified, defaults to "filepath". + filename_suffix: If specified, only partitions that end with this + string will be processed. + credentials: Protocol-specific options that will be passed to + ``fsspec.filesystem`` + https://filesystem-spec.readthedocs.io/en/latest/api.html#fsspec.filesystem + and the dataset initializer. If the dataset config contains + explicit credentials spec, then such spec will take precedence. + All possible credentials management scenarios are documented here: + https://kedro.readthedocs.io/en/stable/data/kedro_io.html#partitioned-dataset-credentials + load_args: Keyword arguments to be passed into ``find()`` method of + the filesystem implementation. + fs_args: Extra arguments to pass into underlying filesystem class constructor + (e.g. `{"project": "my-project"}` for ``GCSFileSystem``) + overwrite: If True, any existing partitions will be removed. + metadata: Any arbitrary metadata. + This is ignored by Kedro, but may be consumed by users or external plugins. + + Raises: + DatasetError: If versioning is enabled for the underlying dataset. + """ + from fsspec.utils import infer_storage_options # for performance reasons + + super().__init__() + + self._path = path + self._filename_suffix = filename_suffix + self._overwrite = overwrite + self._protocol = infer_storage_options(self._path)["protocol"] + self._partition_cache: Cache = Cache(maxsize=1) + self.metadata = metadata + + dataset = dataset if isinstance(dataset, dict) else {"type": dataset} + self._dataset_type, self._dataset_config = parse_dataset_definition(dataset) + if VERSION_KEY in self._dataset_config: + raise DatasetError( + f"'{self.__class__.__name__}' does not support versioning of the " + f"underlying dataset. Please remove '{VERSIONED_FLAG_KEY}' flag from " + f"the dataset definition." 
+ ) + + if credentials: + if CREDENTIALS_KEY in self._dataset_config: + self._logger.warning( + KEY_PROPAGATION_WARNING, + {"keys": CREDENTIALS_KEY, "target": "underlying dataset"}, + ) + else: + self._dataset_config[CREDENTIALS_KEY] = deepcopy(credentials) + + self._credentials = deepcopy(credentials) or {} + + self._fs_args = deepcopy(fs_args) or {} + if self._fs_args: + if "fs_args" in self._dataset_config: + self._logger.warning( + KEY_PROPAGATION_WARNING, + {"keys": "filesystem arguments", "target": "underlying dataset"}, + ) + else: + self._dataset_config["fs_args"] = deepcopy(self._fs_args) + + self._filepath_arg = filepath_arg + if self._filepath_arg in self._dataset_config: + warn( + f"'{self._filepath_arg}' key must not be specified in the dataset " + f"definition as it will be overwritten by partition path" + ) + + self._load_args = deepcopy(load_args) or {} + self._sep = self._filesystem.sep + # since some filesystem implementations may implement a global cache + self._invalidate_caches() + + @property + def _filesystem(self): + protocol = "s3" if self._protocol in S3_PROTOCOLS else self._protocol + return fsspec.filesystem(protocol, **self._credentials, **self._fs_args) + + @property + def _normalized_path(self) -> str: + if self._protocol in S3_PROTOCOLS: + return urlparse(self._path)._replace(scheme="s3").geturl() + return self._path + + @cachedmethod(cache=operator.attrgetter("_partition_cache")) + def _list_partitions(self) -> list[str]: + return [ + path + for path in self._filesystem.find(self._normalized_path, **self._load_args) + if path.endswith(self._filename_suffix) + ] + + def _join_protocol(self, path: str) -> str: + protocol_prefix = f"{self._protocol}://" + if self._path.startswith(protocol_prefix) and not path.startswith( + protocol_prefix + ): + return f"{protocol_prefix}{path}" + return path + + def _partition_to_path(self, path: str): + dir_path = self._path.rstrip(self._sep) + path = path.lstrip(self._sep) + full_path = self._sep.join([dir_path, path]) + self._filename_suffix + return full_path + + def _path_to_partition(self, path: str) -> str: + dir_path = self._filesystem._strip_protocol(self._normalized_path) + path = path.split(dir_path, 1).pop().lstrip(self._sep) + if self._filename_suffix and path.endswith(self._filename_suffix): + path = path[: -len(self._filename_suffix)] + return path + + def _load(self) -> dict[str, Callable[[], Any]]: + partitions = {} + + for partition in self._list_partitions(): + kwargs = deepcopy(self._dataset_config) + # join the protocol back since PySpark may rely on it + kwargs[self._filepath_arg] = self._join_protocol(partition) + dataset = self._dataset_type(**kwargs) # type: ignore + partition_id = self._path_to_partition(partition) + partitions[partition_id] = dataset.load + + if not partitions: + raise DatasetError(f"No partitions found in '{self._path}'") + + return partitions + + def _save(self, data: dict[str, Any]) -> None: + if self._overwrite and self._filesystem.exists(self._normalized_path): + self._filesystem.rm(self._normalized_path, recursive=True) + + for partition_id, partition_data in sorted(data.items()): + kwargs = deepcopy(self._dataset_config) + partition = self._partition_to_path(partition_id) + # join the protocol back since tools like PySpark may rely on it + kwargs[self._filepath_arg] = self._join_protocol(partition) + dataset = self._dataset_type(**kwargs) # type: ignore + if callable(partition_data): + partition_data = partition_data() # noqa: PLW2901 + dataset.save(partition_data) + 
self._invalidate_caches() + + def _describe(self) -> dict[str, Any]: + clean_dataset_config = ( + {k: v for k, v in self._dataset_config.items() if k != CREDENTIALS_KEY} + if isinstance(self._dataset_config, dict) + else self._dataset_config + ) + return { + "path": self._path, + "dataset_type": self._dataset_type.__name__, + "dataset_config": clean_dataset_config, + } + + def _invalidate_caches(self) -> None: + self._partition_cache.clear() + self._filesystem.invalidate_cache(self._normalized_path) + + def _exists(self) -> bool: + return bool(self._list_partitions()) + + def _release(self) -> None: + super()._release() + self._invalidate_caches() diff --git a/kedro-datasets/tests/partitions/__init__.py b/kedro-datasets/tests/partitions/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/kedro-datasets/tests/partitions/test_incremental_dataset.py b/kedro-datasets/tests/partitions/test_incremental_dataset.py new file mode 100644 index 000000000..7c880f88d --- /dev/null +++ b/kedro-datasets/tests/partitions/test_incremental_dataset.py @@ -0,0 +1,508 @@ +from __future__ import annotations + +import os +import re +from pathlib import Path +from typing import Any + +import boto3 +import pandas as pd +import pytest +from kedro.io.core import AbstractDataset, DatasetError +from kedro.io.data_catalog import CREDENTIALS_KEY +from moto import mock_s3 +from pandas.util.testing import assert_frame_equal + +from kedro_datasets.partitions import IncrementalDataset +from kedro_datasets.pickle import PickleDataSet +from kedro_datasets.text import TextDataSet + +DATASET = "kedro_datasets.pandas.csv_dataset.CSVDataSet" + + +@pytest.fixture +def partitioned_data_pandas(): + return { + f"p{counter:02d}/data.csv": pd.DataFrame( + {"part": counter, "col": list(range(counter + 1))} + ) + for counter in range(5) + } + + +@pytest.fixture +def local_csvs(tmp_path, partitioned_data_pandas): + local_dir = Path(tmp_path / "csvs") + local_dir.mkdir() + + for k, data in partitioned_data_pandas.items(): + path = local_dir / k + path.parent.mkdir(parents=True) + data.to_csv(str(path), index=False) + return local_dir + + +class DummyDataset(AbstractDataset): # pragma: no cover + def __init__(self, filepath): + pass + + def _describe(self) -> dict[str, Any]: + return {"dummy": True} + + def _load(self) -> Any: + pass + + def _save(self, data: Any) -> None: + pass + + +def dummy_gt_func(value1: str, value2: str): + return value1 > value2 + + +def dummy_lt_func(value1: str, value2: str): + return value1 < value2 + + +class TestIncrementalDatasetLocal: + def test_load_and_confirm(self, local_csvs, partitioned_data_pandas): + """Test the standard flow for loading, confirming and reloading + an IncrementalDataset""" + pds = IncrementalDataset(str(local_csvs), DATASET) + loaded = pds.load() + assert loaded.keys() == partitioned_data_pandas.keys() + for partition_id, data in loaded.items(): + assert_frame_equal(data, partitioned_data_pandas[partition_id]) + + checkpoint_path = local_csvs / pds.DEFAULT_CHECKPOINT_FILENAME + assert not checkpoint_path.exists() + pds.confirm() + assert checkpoint_path.is_file() + assert checkpoint_path.read_text() == pds._read_checkpoint() == "p04/data.csv" + + reloaded = pds.load() + assert reloaded.keys() == loaded.keys() + + pds.release() + reloaded_after_release = pds.load() + assert not reloaded_after_release + + def test_save(self, local_csvs): + """Test saving a new partition into an IncrementalDataset""" + df = pd.DataFrame({"dummy": [1, 2, 3]}) + new_partition_key = 
"p05/data.csv" + new_partition_path = local_csvs / new_partition_key + pds = IncrementalDataset(str(local_csvs), DATASET) + + assert not new_partition_path.exists() + assert new_partition_key not in pds.load() + + pds.save({new_partition_key: df}) + assert new_partition_path.exists() + loaded = pds.load() + assert_frame_equal(loaded[new_partition_key], df) + + @pytest.mark.parametrize( + "filename_suffix,expected_partitions", + [ + ( + "", + { + "p00/data.csv", + "p01/data.csv", + "p02/data.csv", + "p03/data.csv", + "p04/data.csv", + }, + ), + (".csv", {"p00/data", "p01/data", "p02/data", "p03/data", "p04/data"}), + (".fake", set()), + ], + ) + def test_filename_suffix(self, filename_suffix, expected_partitions, local_csvs): + """Test how specifying filename_suffix affects the available + partitions and their names""" + pds = IncrementalDataset( + str(local_csvs), DATASET, filename_suffix=filename_suffix + ) + loaded = pds.load() + assert loaded.keys() == expected_partitions + + @pytest.mark.parametrize( + "forced_checkpoint,expected_partitions", + [ + ( + "", + { + "p00/data.csv", + "p01/data.csv", + "p02/data.csv", + "p03/data.csv", + "p04/data.csv", + }, + ), + ( + "p00/data.csv", + {"p01/data.csv", "p02/data.csv", "p03/data.csv", "p04/data.csv"}, + ), + ("p03/data.csv", {"p04/data.csv"}), + ], + ) + def test_force_checkpoint_no_checkpoint_file( + self, forced_checkpoint, expected_partitions, local_csvs + ): + """Test how forcing checkpoint value affects the available partitions + if the checkpoint file does not exist""" + pds = IncrementalDataset(str(local_csvs), DATASET, checkpoint=forced_checkpoint) + loaded = pds.load() + assert loaded.keys() == expected_partitions + + confirm_path = local_csvs / pds.DEFAULT_CHECKPOINT_FILENAME + assert not confirm_path.exists() + pds.confirm() + assert confirm_path.is_file() + assert confirm_path.read_text() == max(expected_partitions) + + @pytest.mark.parametrize( + "forced_checkpoint,expected_partitions", + [ + ( + "", + { + "p00/data.csv", + "p01/data.csv", + "p02/data.csv", + "p03/data.csv", + "p04/data.csv", + }, + ), + ( + "p00/data.csv", + {"p01/data.csv", "p02/data.csv", "p03/data.csv", "p04/data.csv"}, + ), + ("p03/data.csv", {"p04/data.csv"}), + ], + ) + def test_force_checkpoint_checkpoint_file_exists( + self, forced_checkpoint, expected_partitions, local_csvs + ): + """Test how forcing checkpoint value affects the available partitions + if the checkpoint file exists""" + IncrementalDataset(str(local_csvs), DATASET).confirm() + checkpoint = local_csvs / IncrementalDataset.DEFAULT_CHECKPOINT_FILENAME + assert checkpoint.read_text() == "p04/data.csv" + + pds = IncrementalDataset(str(local_csvs), DATASET, checkpoint=forced_checkpoint) + assert pds._checkpoint.exists() + loaded = pds.load() + assert loaded.keys() == expected_partitions + + @pytest.mark.parametrize( + "forced_checkpoint", ["p04/data.csv", "p10/data.csv", "p100/data.csv"] + ) + def test_force_checkpoint_no_partitions(self, forced_checkpoint, local_csvs): + """Test that forcing the checkpoint to certain values results in no + partitions being returned""" + pds = IncrementalDataset(str(local_csvs), DATASET, checkpoint=forced_checkpoint) + loaded = pds.load() + assert not loaded + + confirm_path = local_csvs / pds.DEFAULT_CHECKPOINT_FILENAME + assert not confirm_path.exists() + pds.confirm() + # confirming with no partitions available must have no effect + assert not confirm_path.exists() + + def test_checkpoint_path(self, local_csvs, partitioned_data_pandas): + """Test 
configuring a different checkpoint path""" + checkpoint_path = local_csvs / "checkpoint_folder" / "checkpoint_file" + assert not checkpoint_path.exists() + + IncrementalDataset( + str(local_csvs), DATASET, checkpoint={"filepath": str(checkpoint_path)} + ).confirm() + assert checkpoint_path.is_file() + assert checkpoint_path.read_text() == max(partitioned_data_pandas) + + @pytest.mark.parametrize( + "checkpoint_config,expected_checkpoint_class", + [ + (None, TextDataSet), + ({"type": "kedro_datasets.pickle.PickleDataSet"}, PickleDataSet), + ( + {"type": "tests.partitions.test_incremental_dataset.DummyDataset"}, + DummyDataset, + ), + ], + ) + def test_checkpoint_type( + self, tmp_path, checkpoint_config, expected_checkpoint_class + ): + """Test configuring a different checkpoint dataset type""" + pds = IncrementalDataset(str(tmp_path), DATASET, checkpoint=checkpoint_config) + assert isinstance(pds._checkpoint, expected_checkpoint_class) + + @pytest.mark.parametrize( + "checkpoint_config,error_pattern", + [ + ( + {"versioned": True}, + "'IncrementalDataset' does not support versioning " + "of the checkpoint. Please remove 'versioned' key from the " + "checkpoint definition.", + ), + ( + {"version": None}, + "'IncrementalDataset' does not support versioning " + "of the checkpoint. Please remove 'version' key from the " + "checkpoint definition.", + ), + ], + ) + def test_version_not_allowed(self, tmp_path, checkpoint_config, error_pattern): + """Test that invalid checkpoint configurations raise expected errors""" + with pytest.raises(DatasetError, match=re.escape(error_pattern)): + IncrementalDataset(str(tmp_path), DATASET, checkpoint=checkpoint_config) + + @pytest.mark.parametrize( + "pds_config,fs_creds,dataset_creds,checkpoint_creds", + [ + ( + {"dataset": DATASET, "credentials": {"cred": "common"}}, + {"cred": "common"}, + {"cred": "common"}, + {"cred": "common"}, + ), + ( + { + "dataset": {"type": DATASET, "credentials": {"ds": "only"}}, + "credentials": {"cred": "common"}, + }, + {"cred": "common"}, + {"ds": "only"}, + {"cred": "common"}, + ), + ( + { + "dataset": DATASET, + "credentials": {"cred": "common"}, + "checkpoint": {"credentials": {"cp": "only"}}, + }, + {"cred": "common"}, + {"cred": "common"}, + {"cp": "only"}, + ), + ( + { + "dataset": {"type": DATASET, "credentials": {"ds": "only"}}, + "checkpoint": {"credentials": {"cp": "only"}}, + }, + {}, + {"ds": "only"}, + {"cp": "only"}, + ), + ( + { + "dataset": {"type": DATASET, "credentials": None}, + "credentials": {"cred": "common"}, + "checkpoint": {"credentials": None}, + }, + {"cred": "common"}, + None, + None, + ), + ], + ) + def test_credentials(self, pds_config, fs_creds, dataset_creds, checkpoint_creds): + """Test correctness of credentials propagation into the dataset and + checkpoint constructors""" + pds = IncrementalDataset(str(Path.cwd()), **pds_config) + assert pds._credentials == fs_creds + assert pds._dataset_config[CREDENTIALS_KEY] == dataset_creds + assert pds._checkpoint_config[CREDENTIALS_KEY] == checkpoint_creds + + @pytest.mark.parametrize( + "comparison_func,expected_partitions", + [ + ( + "tests.partitions.test_incremental_dataset.dummy_gt_func", + {"p03/data.csv", "p04/data.csv"}, + ), + (dummy_gt_func, {"p03/data.csv", "p04/data.csv"}), + ( + "tests.partitions.test_incremental_dataset.dummy_lt_func", + {"p00/data.csv", "p01/data.csv"}, + ), + (dummy_lt_func, {"p00/data.csv", "p01/data.csv"}), + ], + ) + def test_comparison_func(self, comparison_func, expected_partitions, local_csvs): + """Test that 
specifying a custom function for comparing the checkpoint value + to a partition id results in expected partitions being returned on load""" + checkpoint_config = { + "force_checkpoint": "p02/data.csv", + "comparison_func": comparison_func, + } + pds = IncrementalDataset(str(local_csvs), DATASET, checkpoint=checkpoint_config) + assert pds.load().keys() == expected_partitions + + +BUCKET_NAME = "fake_bucket_name" + + +@pytest.fixture +def mocked_s3_bucket(): + """Create a bucket for testing using moto.""" + with mock_s3(): + conn = boto3.client( + "s3", + aws_access_key_id="fake_access_key", + aws_secret_access_key="fake_secret_key", + ) + conn.create_bucket(Bucket=BUCKET_NAME) + yield conn + + +@pytest.fixture +def mocked_csvs_in_s3(mocked_s3_bucket, partitioned_data_pandas): + prefix = "csvs" + for key, data in partitioned_data_pandas.items(): + mocked_s3_bucket.put_object( + Bucket=BUCKET_NAME, + Key=f"{prefix}/{key}", + Body=data.to_csv(index=False), + ) + return f"s3://{BUCKET_NAME}/{prefix}" + + +class TestPartitionedDatasetS3: + os.environ["AWS_ACCESS_KEY_ID"] = "FAKE_ACCESS_KEY" + os.environ["AWS_SECRET_ACCESS_KEY"] = "FAKE_SECRET_KEY" + + def test_load_and_confirm(self, mocked_csvs_in_s3, partitioned_data_pandas): + """Test the standard flow for loading, confirming and reloading + a IncrementalDataset in S3""" + pds = IncrementalDataset(mocked_csvs_in_s3, DATASET) + assert pds._checkpoint._protocol == "s3" + loaded = pds.load() + assert loaded.keys() == partitioned_data_pandas.keys() + for partition_id, data in loaded.items(): + assert_frame_equal(data, partitioned_data_pandas[partition_id]) + + assert not pds._checkpoint.exists() + assert pds._read_checkpoint() is None + pds.confirm() + assert pds._checkpoint.exists() + assert pds._read_checkpoint() == max(partitioned_data_pandas) + + def test_load_and_confirm_s3a( + self, mocked_csvs_in_s3, partitioned_data_pandas, mocker + ): + s3a_path = f"s3a://{mocked_csvs_in_s3.split('://', 1)[1]}" + pds = IncrementalDataset(s3a_path, DATASET) + assert pds._protocol == "s3a" + assert pds._checkpoint._protocol == "s3" + + mocked_ds = mocker.patch.object(pds, "_dataset_type") + mocked_ds.__name__ = "mocked" + loaded = pds.load() + + assert loaded.keys() == partitioned_data_pandas.keys() + assert not pds._checkpoint.exists() + assert pds._read_checkpoint() is None + pds.confirm() + assert pds._checkpoint.exists() + assert pds._read_checkpoint() == max(partitioned_data_pandas) + + @pytest.mark.parametrize( + "forced_checkpoint,expected_partitions", + [ + ( + "", + { + "p00/data.csv", + "p01/data.csv", + "p02/data.csv", + "p03/data.csv", + "p04/data.csv", + }, + ), + ( + "p00/data.csv", + {"p01/data.csv", "p02/data.csv", "p03/data.csv", "p04/data.csv"}, + ), + ("p03/data.csv", {"p04/data.csv"}), + ], + ) + def test_force_checkpoint_no_checkpoint_file( + self, forced_checkpoint, expected_partitions, mocked_csvs_in_s3 + ): + """Test how forcing checkpoint value affects the available partitions + in S3 if the checkpoint file does not exist""" + pds = IncrementalDataset( + mocked_csvs_in_s3, DATASET, checkpoint=forced_checkpoint + ) + loaded = pds.load() + assert loaded.keys() == expected_partitions + + assert not pds._checkpoint.exists() + pds.confirm() + assert pds._checkpoint.exists() + assert pds._checkpoint.load() == max(expected_partitions) + + @pytest.mark.parametrize( + "forced_checkpoint,expected_partitions", + [ + ( + "", + { + "p00/data.csv", + "p01/data.csv", + "p02/data.csv", + "p03/data.csv", + "p04/data.csv", + }, + ), + ( + 
"p00/data.csv", + {"p01/data.csv", "p02/data.csv", "p03/data.csv", "p04/data.csv"}, + ), + ("p03/data.csv", {"p04/data.csv"}), + ], + ) + def test_force_checkpoint_checkpoint_file_exists( + self, forced_checkpoint, expected_partitions, mocked_csvs_in_s3 + ): + """Test how forcing checkpoint value affects the available partitions + in S3 if the checkpoint file exists""" + # create checkpoint and assert that it exists + IncrementalDataset(mocked_csvs_in_s3, DATASET).confirm() + checkpoint_path = ( + f"{mocked_csvs_in_s3}/{IncrementalDataset.DEFAULT_CHECKPOINT_FILENAME}" + ) + checkpoint_value = TextDataSet(checkpoint_path).load() + assert checkpoint_value == "p04/data.csv" + + pds = IncrementalDataset( + mocked_csvs_in_s3, DATASET, checkpoint=forced_checkpoint + ) + assert pds._checkpoint.exists() + loaded = pds.load() + assert loaded.keys() == expected_partitions + + @pytest.mark.parametrize( + "forced_checkpoint", ["p04/data.csv", "p10/data.csv", "p100/data.csv"] + ) + def test_force_checkpoint_no_partitions(self, forced_checkpoint, mocked_csvs_in_s3): + """Test that forcing the checkpoint to certain values results in no + partitions returned from S3""" + pds = IncrementalDataset( + mocked_csvs_in_s3, DATASET, checkpoint=forced_checkpoint + ) + loaded = pds.load() + assert not loaded + + assert not pds._checkpoint.exists() + pds.confirm() + # confirming with no partitions available must have no effect + assert not pds._checkpoint.exists() diff --git a/kedro-datasets/tests/partitions/test_partitioned_dataset.py b/kedro-datasets/tests/partitions/test_partitioned_dataset.py new file mode 100644 index 000000000..4feb79ac4 --- /dev/null +++ b/kedro-datasets/tests/partitions/test_partitioned_dataset.py @@ -0,0 +1,540 @@ +import logging +import os +import re +from pathlib import Path + +import boto3 +import pandas as pd +import pytest +import s3fs +from kedro.io import DatasetError +from kedro.io.data_catalog import CREDENTIALS_KEY +from moto import mock_s3 +from pandas.util.testing import assert_frame_equal + +from kedro_datasets.pandas import CSVDataset, ParquetDataset +from kedro_datasets.partitions import PartitionedDataset +from kedro_datasets.partitions.partitioned_dataset import KEY_PROPAGATION_WARNING + + +@pytest.fixture +def partitioned_data_pandas(): + keys = ("p1/data1.csv", "p2.csv", "p1/data2.csv", "p3", "_p4") + return { + k: pd.DataFrame({"part": k, "counter": list(range(counter))}) + for counter, k in enumerate(keys, 1) + } + + +@pytest.fixture +def local_csvs(tmp_path, partitioned_data_pandas): + local_dir = Path(str(tmp_path / "csvs")) + local_dir.mkdir() + + for k, data in partitioned_data_pandas.items(): + path = local_dir / k + path.parent.mkdir(parents=True, exist_ok=True) + data.to_csv(str(path), index=False) + return local_dir + + +LOCAL_DATASET_DEFINITION = [ + "pandas.CSVDataset", + "kedro_datasets.pandas.CSVDataset", + CSVDataset, + {"type": "kedro_datasets.pandas.CSVDataset", "save_args": {"index": False}}, + {"type": CSVDataset}, +] + + +class FakeDataset: # pylint: disable=too-few-public-methods + pass + + +class TestPartitionedDatasetLocal: + @pytest.mark.parametrize("dataset", LOCAL_DATASET_DEFINITION) + @pytest.mark.parametrize( + "suffix,expected_num_parts", [("", 5), (".csv", 3), ("p4", 1)] + ) + def test_load( + self, dataset, local_csvs, partitioned_data_pandas, suffix, expected_num_parts + ): + pds = PartitionedDataset(str(local_csvs), dataset, filename_suffix=suffix) + loaded_partitions = pds.load() + + assert len(loaded_partitions.keys()) == 
expected_num_parts + for partition_id, load_func in loaded_partitions.items(): + df = load_func() + assert_frame_equal(df, partitioned_data_pandas[partition_id + suffix]) + if suffix: + assert not partition_id.endswith(suffix) + + @pytest.mark.parametrize("dataset", LOCAL_DATASET_DEFINITION) + @pytest.mark.parametrize("suffix", ["", ".csv"]) + def test_save(self, dataset, local_csvs, suffix): + pds = PartitionedDataset(str(local_csvs), dataset, filename_suffix=suffix) + original_data = pd.DataFrame({"foo": 42, "bar": ["a", "b", None]}) + part_id = "new/data" + pds.save({part_id: original_data}) + + assert (local_csvs / "new" / ("data" + suffix)).is_file() + loaded_partitions = pds.load() + assert part_id in loaded_partitions + reloaded_data = loaded_partitions[part_id]() + assert_frame_equal(reloaded_data, original_data) + + @pytest.mark.parametrize("dataset", LOCAL_DATASET_DEFINITION) + @pytest.mark.parametrize("suffix", ["", ".csv"]) + def test_lazy_save(self, dataset, local_csvs, suffix): + pds = PartitionedDataset(str(local_csvs), dataset, filename_suffix=suffix) + + def original_data(): + return pd.DataFrame({"foo": 42, "bar": ["a", "b", None]}) + + part_id = "new/data" + pds.save({part_id: original_data}) + + assert (local_csvs / "new" / ("data" + suffix)).is_file() + loaded_partitions = pds.load() + assert part_id in loaded_partitions + reloaded_data = loaded_partitions[part_id]() + assert_frame_equal(reloaded_data, original_data()) + + def test_save_invalidates_cache(self, local_csvs, mocker): + """Test that save calls invalidate partition cache""" + pds = PartitionedDataset(str(local_csvs), "pandas.CSVDataset") + mocked_fs_invalidate = mocker.patch.object(pds._filesystem, "invalidate_cache") + first_load = pds.load() + assert pds._partition_cache.currsize == 1 + mocked_fs_invalidate.assert_not_called() + + # save clears cache + data = pd.DataFrame({"foo": 42, "bar": ["a", "b", None]}) + new_partition = "new/data.csv" + pds.save({new_partition: data}) + assert pds._partition_cache.currsize == 0 + # it seems that `_filesystem.invalidate_cache` calls itself inside, + # resulting in not one, but 2 mock calls + # hence using `assert_any_call` instead of `assert_called_once_with` + mocked_fs_invalidate.assert_any_call(pds._normalized_path) + + # new load returns new partition too + second_load = pds.load() + assert new_partition not in first_load + assert new_partition in second_load + + @pytest.mark.parametrize("overwrite,expected_num_parts", [(False, 6), (True, 1)]) + def test_overwrite(self, local_csvs, overwrite, expected_num_parts): + pds = PartitionedDataset( + str(local_csvs), "pandas.CSVDataset", overwrite=overwrite + ) + original_data = pd.DataFrame({"foo": 42, "bar": ["a", "b", None]}) + part_id = "new/data" + pds.save({part_id: original_data}) + loaded_partitions = pds.load() + + assert part_id in loaded_partitions + assert len(loaded_partitions.keys()) == expected_num_parts + + def test_release_instance_cache(self, local_csvs): + """Test that cache invalidation does not affect other instances""" + ds_a = PartitionedDataset(str(local_csvs), "pandas.CSVDataset") + ds_a.load() + ds_b = PartitionedDataset(str(local_csvs), "pandas.CSVDataset") + ds_b.load() + + assert ds_a._partition_cache.currsize == 1 + assert ds_b._partition_cache.currsize == 1 + + # invalidate cache of the dataset A + ds_a.release() + assert ds_a._partition_cache.currsize == 0 + # cache of the dataset B is unaffected + assert ds_b._partition_cache.currsize == 1 + + @pytest.mark.parametrize("dataset", 
["pandas.CSVDataset", "pandas.ParquetDataset"]) + def test_exists(self, local_csvs, dataset): + assert PartitionedDataset(str(local_csvs), dataset).exists() + + empty_folder = local_csvs / "empty" / "folder" + assert not PartitionedDataset(str(empty_folder), dataset).exists() + empty_folder.mkdir(parents=True) + assert not PartitionedDataset(str(empty_folder), dataset).exists() + + @pytest.mark.parametrize("dataset", LOCAL_DATASET_DEFINITION) + def test_release(self, dataset, local_csvs): + partition_to_remove = "p2.csv" + pds = PartitionedDataset(str(local_csvs), dataset) + initial_load = pds.load() + assert partition_to_remove in initial_load + + (local_csvs / partition_to_remove).unlink() + cached_load = pds.load() + assert initial_load.keys() == cached_load.keys() + + pds.release() + load_after_release = pds.load() + assert initial_load.keys() ^ load_after_release.keys() == {partition_to_remove} + + @pytest.mark.parametrize("dataset", LOCAL_DATASET_DEFINITION) + def test_describe(self, dataset): + path = str(Path.cwd()) + pds = PartitionedDataset(path, dataset) + + assert f"path={path}" in str(pds) + assert "dataset_type=CSVDataset" in str(pds) + assert "dataset_config" in str(pds) + + def test_load_args(self, mocker): + fake_partition_name = "fake_partition" + mocked_filesystem = mocker.patch("fsspec.filesystem") + mocked_find = mocked_filesystem.return_value.find + mocked_find.return_value = [fake_partition_name] + + path = str(Path.cwd()) + load_args = {"maxdepth": 42, "withdirs": True} + pds = PartitionedDataset(path, "pandas.CSVDataset", load_args=load_args) + mocker.patch.object(pds, "_path_to_partition", return_value=fake_partition_name) + + assert pds.load().keys() == {fake_partition_name} + mocked_find.assert_called_once_with(path, **load_args) + + @pytest.mark.parametrize( + "credentials,expected_pds_creds,expected_dataset_creds", + [({"cred": "common"}, {"cred": "common"}, {"cred": "common"}), (None, {}, {})], + ) + def test_credentials( + self, mocker, credentials, expected_pds_creds, expected_dataset_creds + ): + mocked_filesystem = mocker.patch("fsspec.filesystem") + path = str(Path.cwd()) + pds = PartitionedDataset(path, "pandas.CSVDataset", credentials=credentials) + + assert mocked_filesystem.call_count == 2 + mocked_filesystem.assert_called_with("file", **expected_pds_creds) + if expected_dataset_creds: + assert pds._dataset_config[CREDENTIALS_KEY] == expected_dataset_creds + else: + assert CREDENTIALS_KEY not in pds._dataset_config + + str_repr = str(pds) + + def _assert_not_in_repr(value): + if isinstance(value, dict): + for k_, v_ in value.items(): + _assert_not_in_repr(k_) + _assert_not_in_repr(v_) + if value is not None: + assert str(value) not in str_repr + + _assert_not_in_repr(credentials) + + def test_fs_args(self, mocker): + fs_args = {"foo": "bar"} + + mocked_filesystem = mocker.patch("fsspec.filesystem") + path = str(Path.cwd()) + pds = PartitionedDataset(path, "pandas.CSVDataset", fs_args=fs_args) + + assert mocked_filesystem.call_count == 2 + mocked_filesystem.assert_called_with("file", **fs_args) + assert pds._dataset_config["fs_args"] == fs_args + + @pytest.mark.parametrize("dataset", ["pandas.ParquetDataset", ParquetDataset]) + def test_invalid_dataset(self, dataset, local_csvs): + pds = PartitionedDataset(str(local_csvs), dataset) + loaded_partitions = pds.load() + + for partition, df_loader in loaded_partitions.items(): + pattern = r"Failed while loading data from data set ParquetDataset(.*)" + with pytest.raises(DatasetError, match=pattern) as 
exc_info: + df_loader() + error_message = str(exc_info.value) + assert ( + "Either the file is corrupted or this is not a parquet file" + in error_message + ) + assert str(partition) in error_message + + @pytest.mark.parametrize( + "dataset_config,error_pattern", + [ + ("UndefinedDatasetType", "Class 'UndefinedDatasetType' not found"), + ( + "missing.module.UndefinedDatasetType", + r"Class 'missing\.module\.UndefinedDatasetType' not found", + ), + ( + FakeDataset, + r"Dataset type 'tests\.partitions\.test_partitioned_dataset\.FakeDataset' " + r"is invalid\: all data set types must extend 'AbstractDataset'", + ), + ({}, "'type' is missing from dataset catalog configuration"), + ], + ) + def test_invalid_dataset_config(self, dataset_config, error_pattern): + with pytest.raises(DatasetError, match=error_pattern): + PartitionedDataset(str(Path.cwd()), dataset_config) + + @pytest.mark.parametrize( + "dataset_config", + [ + {"type": CSVDataset, "versioned": True}, + {"type": "pandas.CSVDataset", "versioned": True}, + ], + ) + def test_versioned_dataset_not_allowed(self, dataset_config): + pattern = ( + "'PartitionedDataset' does not support versioning of the underlying " + "dataset. Please remove 'versioned' flag from the dataset definition." + ) + with pytest.raises(DatasetError, match=re.escape(pattern)): + PartitionedDataset(str(Path.cwd()), dataset_config) + + def test_no_partitions(self, tmpdir): + pds = PartitionedDataset(str(tmpdir), "pandas.CSVDataset") + + pattern = re.escape(f"No partitions found in '{tmpdir}'") + with pytest.raises(DatasetError, match=pattern): + pds.load() + + @pytest.mark.parametrize( + "pds_config,filepath_arg", + [ + ( + { + "path": str(Path.cwd()), + "dataset": {"type": CSVDataset, "filepath": "fake_path"}, + }, + "filepath", + ), + ( + { + "path": str(Path.cwd()), + "dataset": {"type": CSVDataset, "other_arg": "fake_path"}, + "filepath_arg": "other_arg", + }, + "other_arg", + ), + ], + ) + def test_filepath_arg_warning(self, pds_config, filepath_arg): + pattern = ( + f"'{filepath_arg}' key must not be specified in the dataset definition as it " + f"will be overwritten by partition path" + ) + with pytest.warns(UserWarning, match=re.escape(pattern)): + PartitionedDataset(**pds_config) + + def test_credentials_log_warning(self, caplog): + """Check that the warning is logged if the dataset credentials will overwrite + the top-level ones""" + pds = PartitionedDataset( + path=str(Path.cwd()), + dataset={"type": CSVDataset, "credentials": {"secret": "dataset"}}, + credentials={"secret": "global"}, + ) + log_message = KEY_PROPAGATION_WARNING % { + "keys": "credentials", + "target": "underlying dataset", + } + assert caplog.record_tuples == [("kedro.io.core", logging.WARNING, log_message)] + assert pds._dataset_config["credentials"] == {"secret": "dataset"} + + def test_fs_args_log_warning(self, caplog): + """Check that the warning is logged if the dataset filesystem + arguments will overwrite the top-level ones""" + pds = PartitionedDataset( + path=str(Path.cwd()), + dataset={"type": CSVDataset, "fs_args": {"args": "dataset"}}, + fs_args={"args": "dataset"}, + ) + log_message = KEY_PROPAGATION_WARNING % { + "keys": "filesystem arguments", + "target": "underlying dataset", + } + assert caplog.record_tuples == [("kedro.io.core", logging.WARNING, log_message)] + assert pds._dataset_config["fs_args"] == {"args": "dataset"} + + @pytest.mark.parametrize( + "pds_config,expected_ds_creds,global_creds", + [ + ( + {"dataset": "pandas.CSVDataset", "credentials": {"secret": 
"global"}}, + {"secret": "global"}, + {"secret": "global"}, + ), + ( + { + "dataset": { + "type": CSVDataset, + "credentials": {"secret": "expected"}, + }, + }, + {"secret": "expected"}, + {}, + ), + ( + { + "dataset": {"type": CSVDataset, "credentials": None}, + "credentials": {"secret": "global"}, + }, + None, + {"secret": "global"}, + ), + ( + { + "dataset": { + "type": CSVDataset, + "credentials": {"secret": "expected"}, + }, + "credentials": {"secret": "global"}, + }, + {"secret": "expected"}, + {"secret": "global"}, + ), + ], + ) + def test_dataset_creds(self, pds_config, expected_ds_creds, global_creds): + """Check that global credentials do not interfere dataset credentials.""" + pds = PartitionedDataset(path=str(Path.cwd()), **pds_config) + assert pds._dataset_config["credentials"] == expected_ds_creds + assert pds._credentials == global_creds + + +BUCKET_NAME = "fake_bucket_name" +S3_DATASET_DEFINITION = [ + "pandas.CSVDataset", + "kedro_datasets.pandas.CSVDataset", + CSVDataset, + {"type": "kedro_datasets.pandas.CSVDataset", "save_args": {"index": False}}, + {"type": CSVDataset}, +] + + +@pytest.fixture +def mocked_s3_bucket(): + """Create a bucket for testing using moto.""" + with mock_s3(): + conn = boto3.client( + "s3", + aws_access_key_id="fake_access_key", + aws_secret_access_key="fake_secret_key", + ) + conn.create_bucket(Bucket=BUCKET_NAME) + yield conn + + +@pytest.fixture +def mocked_csvs_in_s3(mocked_s3_bucket, partitioned_data_pandas): + prefix = "csvs" + for key, data in partitioned_data_pandas.items(): + mocked_s3_bucket.put_object( + Bucket=BUCKET_NAME, + Key=f"{prefix}/{key}", + Body=data.to_csv(index=False), + ) + return f"s3://{BUCKET_NAME}/{prefix}" + + +class TestPartitionedDatasetS3: + os.environ["AWS_ACCESS_KEY_ID"] = "FAKE_ACCESS_KEY" + os.environ["AWS_SECRET_ACCESS_KEY"] = "FAKE_SECRET_KEY" + + @pytest.mark.parametrize("dataset", S3_DATASET_DEFINITION) + def test_load(self, dataset, mocked_csvs_in_s3, partitioned_data_pandas): + pds = PartitionedDataset(mocked_csvs_in_s3, dataset) + loaded_partitions = pds.load() + + assert loaded_partitions.keys() == partitioned_data_pandas.keys() + for partition_id, load_func in loaded_partitions.items(): + df = load_func() + assert_frame_equal(df, partitioned_data_pandas[partition_id]) + + def test_load_s3a(self, mocked_csvs_in_s3, partitioned_data_pandas, mocker): + path = mocked_csvs_in_s3.split("://", 1)[1] + s3a_path = f"s3a://{path}" + # any type is fine as long as it passes isinstance check + # since _dataset_type is mocked later anyways + pds = PartitionedDataset(s3a_path, "pandas.CSVDataset") + assert pds._protocol == "s3a" + + mocked_ds = mocker.patch.object(pds, "_dataset_type") + mocked_ds.__name__ = "mocked" + loaded_partitions = pds.load() + + assert loaded_partitions.keys() == partitioned_data_pandas.keys() + assert mocked_ds.call_count == len(loaded_partitions) + expected = [ + mocker.call(filepath=f"{s3a_path}/{partition_id}") + for partition_id in loaded_partitions + ] + mocked_ds.assert_has_calls(expected, any_order=True) + + @pytest.mark.parametrize("dataset", S3_DATASET_DEFINITION) + def test_save(self, dataset, mocked_csvs_in_s3): + pds = PartitionedDataset(mocked_csvs_in_s3, dataset) + original_data = pd.DataFrame({"foo": 42, "bar": ["a", "b", None]}) + part_id = "new/data.csv" + pds.save({part_id: original_data}) + + s3 = s3fs.S3FileSystem() + assert s3.exists("/".join([mocked_csvs_in_s3, part_id])) + + loaded_partitions = pds.load() + assert part_id in loaded_partitions + reloaded_data = 
loaded_partitions[part_id]() + assert_frame_equal(reloaded_data, original_data) + + def test_save_s3a(self, mocked_csvs_in_s3, mocker): + """Test that save works in case of s3a protocol""" + path = mocked_csvs_in_s3.split("://", 1)[1] + s3a_path = f"s3a://{path}" + # any type is fine as long as it passes isinstance check + # since _dataset_type is mocked later anyways + pds = PartitionedDataset(s3a_path, "pandas.CSVDataset", filename_suffix=".csv") + assert pds._protocol == "s3a" + + mocked_ds = mocker.patch.object(pds, "_dataset_type") + mocked_ds.__name__ = "mocked" + new_partition = "new/data" + data = "data" + + pds.save({new_partition: data}) + mocked_ds.assert_called_once_with(filepath=f"{s3a_path}/{new_partition}.csv") + mocked_ds.return_value.save.assert_called_once_with(data) + + @pytest.mark.parametrize("dataset", ["pandas.CSVDataset", "pandas.HDFDataset"]) + def test_exists(self, dataset, mocked_csvs_in_s3): + assert PartitionedDataset(mocked_csvs_in_s3, dataset).exists() + + empty_folder = "/".join([mocked_csvs_in_s3, "empty", "folder"]) + assert not PartitionedDataset(empty_folder, dataset).exists() + + s3fs.S3FileSystem().mkdir(empty_folder) + assert not PartitionedDataset(empty_folder, dataset).exists() + + @pytest.mark.parametrize("dataset", S3_DATASET_DEFINITION) + def test_release(self, dataset, mocked_csvs_in_s3): + partition_to_remove = "p2.csv" + pds = PartitionedDataset(mocked_csvs_in_s3, dataset) + initial_load = pds.load() + assert partition_to_remove in initial_load + + s3 = s3fs.S3FileSystem() + s3.rm("/".join([mocked_csvs_in_s3, partition_to_remove])) + cached_load = pds.load() + assert initial_load.keys() == cached_load.keys() + + pds.release() + load_after_release = pds.load() + assert initial_load.keys() ^ load_after_release.keys() == {partition_to_remove} + + @pytest.mark.parametrize("dataset", S3_DATASET_DEFINITION) + def test_describe(self, dataset): + path = f"s3://{BUCKET_NAME}/foo/bar" + pds = PartitionedDataset(path, dataset) + + assert f"path={path}" in str(pds) + assert "dataset_type=CSVDataset" in str(pds) + assert "dataset_config" in str(pds) From 2ad6b14a80b364cb492acd273ab4d5c013968ccd Mon Sep 17 00:00:00 2001 From: Merel Theisen Date: Thu, 5 Oct 2023 16:47:40 +0100 Subject: [PATCH 2/9] Run only tensorflow tests Signed-off-by: Merel Theisen --- .github/workflows/unit-tests.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/unit-tests.yml b/.github/workflows/unit-tests.yml index 5f479afa5..59abdd3ab 100644 --- a/.github/workflows/unit-tests.yml +++ b/.github/workflows/unit-tests.yml @@ -51,7 +51,7 @@ jobs: run: pip freeze - name: Run unit tests for Linux / all plugins if: inputs.os != 'windows-latest' - run: make plugin=${{ inputs.plugin }} test + run: pytest tests/tensorflow/test_tensorflow_model_dataset.py - name: Run unit tests for Windows / kedro-airflow, kedro-docker, kedro-telemetry if: inputs.os == 'windows-latest' && inputs.plugin != 'kedro-datasets' run: | From 99a2ee82ed3fc2f79ee093d81a206a42fbb3c973 Mon Sep 17 00:00:00 2001 From: Merel Theisen Date: Thu, 5 Oct 2023 16:55:12 +0100 Subject: [PATCH 3/9] Run only tensorflow tests Signed-off-by: Merel Theisen --- .github/workflows/unit-tests.yml | 2 +- Makefile | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/unit-tests.yml b/.github/workflows/unit-tests.yml index 59abdd3ab..5f479afa5 100644 --- a/.github/workflows/unit-tests.yml +++ b/.github/workflows/unit-tests.yml @@ -51,7 +51,7 @@ jobs: run: pip freeze - name: Run 
unit tests for Linux / all plugins if: inputs.os != 'windows-latest' - run: pytest tests/tensorflow/test_tensorflow_model_dataset.py + run: make plugin=${{ inputs.plugin }} test - name: Run unit tests for Windows / kedro-airflow, kedro-docker, kedro-telemetry if: inputs.os == 'windows-latest' && inputs.plugin != 'kedro-datasets' run: | diff --git a/Makefile b/Makefile index 03e74bec0..5136c7233 100644 --- a/Makefile +++ b/Makefile @@ -17,7 +17,7 @@ lint: pre-commit run -a --hook-stage manual ruff-$(plugin) && pre-commit run trailing-whitespace --all-files && pre-commit run end-of-file-fixer --all-files && pre-commit run check-yaml --all-files && pre-commit run check-added-large-files --all-files && pre-commit run check-case-conflict --all-files && pre-commit run check-merge-conflict --all-files && pre-commit run debug-statements --all-files && pre-commit run black-$(plugin) --all-files --hook-stage manual && pre-commit run secret_scan --all-files --hook-stage manual && pre-commit run bandit --all-files --hook-stage manual test: - cd $(plugin) && pytest tests --cov-config pyproject.toml --numprocesses 4 --dist loadfile + cd $(plugin) && pytest tests/tensorflow/test_tensorflow_model_dataset.py test-sequential: cd $(plugin) && pytest tests --cov-config pyproject.toml From aefd444b2b56bf29aca11ed216ffc098d10e3197 Mon Sep 17 00:00:00 2001 From: Merel Theisen Date: Thu, 5 Oct 2023 17:14:45 +0100 Subject: [PATCH 4/9] Run tensorflow datasets separately Signed-off-by: Merel Theisen --- .github/workflows/unit-tests.yml | 7 +++++-- Makefile | 6 +++++- 2 files changed, 10 insertions(+), 3 deletions(-) diff --git a/.github/workflows/unit-tests.yml b/.github/workflows/unit-tests.yml index 5f479afa5..6e3e2ecb7 100644 --- a/.github/workflows/unit-tests.yml +++ b/.github/workflows/unit-tests.yml @@ -49,9 +49,12 @@ jobs: pip install ".[test]" - name: pip freeze run: pip freeze - - name: Run unit tests for Linux / all plugins - if: inputs.os != 'windows-latest' + - name: Run unit tests for Linux / kedro-airflow, kedro-docker, kedro-telemetry + if: inputs.os != 'windows-latest' && inputs.plugin != 'kedro-datasets' run: make plugin=${{ inputs.plugin }} test + - name: Run unit tests for Linux / kedro-datasets + if: inputs.os != 'windows-latest' && inputs.plugin == 'kedro-datasets' + run: make dataset-tests - name: Run unit tests for Windows / kedro-airflow, kedro-docker, kedro-telemetry if: inputs.os == 'windows-latest' && inputs.plugin != 'kedro-datasets' run: | diff --git a/Makefile b/Makefile index 5136c7233..535961a72 100644 --- a/Makefile +++ b/Makefile @@ -17,7 +17,11 @@ lint: pre-commit run -a --hook-stage manual ruff-$(plugin) && pre-commit run trailing-whitespace --all-files && pre-commit run end-of-file-fixer --all-files && pre-commit run check-yaml --all-files && pre-commit run check-added-large-files --all-files && pre-commit run check-case-conflict --all-files && pre-commit run check-merge-conflict --all-files && pre-commit run debug-statements --all-files && pre-commit run black-$(plugin) --all-files --hook-stage manual && pre-commit run secret_scan --all-files --hook-stage manual && pre-commit run bandit --all-files --hook-stage manual test: - cd $(plugin) && pytest tests/tensorflow/test_tensorflow_model_dataset.py + cd $(plugin) && pytest tests --cov-config pyproject.toml --numprocesses 4 --dist loadfile + +dataset-tests: + cd kedro-datasets && pytest tests --cov-config pyproject.toml --numprocesses 4 --dist loadfile --ignore tests/tensorflow + cd kedro-datasets && pytest 
tests/tensorflow/test_tensorflow_model_dataset.py test-sequential: cd $(plugin) && pytest tests --cov-config pyproject.toml From fbd72d2f854c2ddd0e968d62ad2089fe5059d5a2 Mon Sep 17 00:00:00 2001 From: Merel Theisen Date: Thu, 5 Oct 2023 17:24:55 +0100 Subject: [PATCH 5/9] Ignore tensorflow tests in coverage Signed-off-by: Merel Theisen --- kedro-datasets/pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/kedro-datasets/pyproject.toml b/kedro-datasets/pyproject.toml index d5be97bbc..e485149ed 100644 --- a/kedro-datasets/pyproject.toml +++ b/kedro-datasets/pyproject.toml @@ -31,7 +31,7 @@ version = {attr = "kedro_datasets.__version__"} [tool.coverage.report] fail_under = 100 show_missing = true -omit = ["tests/*", "kedro_datasets/holoviews/*", "kedro_datasets/snowflake/*"] +omit = ["tests/*", "kedro_datasets/holoviews/*", "kedro_datasets/snowflake/*", "kedro_datasets/tensorflow/*"] exclude_lines = ["pragma: no cover", "raise NotImplementedError"] [tool.pytest.ini_options] From c6e4a8f88ec397f80238e7323757c47f47f89992 Mon Sep 17 00:00:00 2001 From: Merel Theisen Date: Thu, 5 Oct 2023 17:34:55 +0100 Subject: [PATCH 6/9] Don't check coverage for only tensorflow tests Signed-off-by: Merel Theisen --- Makefile | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Makefile b/Makefile index 535961a72..be475b09b 100644 --- a/Makefile +++ b/Makefile @@ -21,7 +21,7 @@ test: dataset-tests: cd kedro-datasets && pytest tests --cov-config pyproject.toml --numprocesses 4 --dist loadfile --ignore tests/tensorflow - cd kedro-datasets && pytest tests/tensorflow/test_tensorflow_model_dataset.py + cd kedro-datasets && pytest tests/tensorflow/test_tensorflow_model_dataset.py --no-cov test-sequential: cd $(plugin) && pytest tests --cov-config pyproject.toml From 14f13653582bc110b00190bcd29f425eace0d32b Mon Sep 17 00:00:00 2001 From: Merel Theisen Date: Thu, 5 Oct 2023 17:46:31 +0100 Subject: [PATCH 7/9] Remove new datasets again, to keep this PR to build changes only Signed-off-by: Merel Theisen --- kedro-datasets/RELEASE.md | 2 - .../kedro_datasets/partitions/__init__.py | 11 - .../partitions/incremental_dataset.py | 237 -------- .../partitions/partitioned_dataset.py | 329 ----------- kedro-datasets/tests/partitions/__init__.py | 0 .../partitions/test_incremental_dataset.py | 508 ---------------- .../partitions/test_partitioned_dataset.py | 540 ------------------ 7 files changed, 1627 deletions(-) delete mode 100644 kedro-datasets/kedro_datasets/partitions/__init__.py delete mode 100644 kedro-datasets/kedro_datasets/partitions/incremental_dataset.py delete mode 100644 kedro-datasets/kedro_datasets/partitions/partitioned_dataset.py delete mode 100644 kedro-datasets/tests/partitions/__init__.py delete mode 100644 kedro-datasets/tests/partitions/test_incremental_dataset.py delete mode 100644 kedro-datasets/tests/partitions/test_partitioned_dataset.py diff --git a/kedro-datasets/RELEASE.md b/kedro-datasets/RELEASE.md index 168e7d72f..9c6661fda 100755 --- a/kedro-datasets/RELEASE.md +++ b/kedro-datasets/RELEASE.md @@ -1,7 +1,5 @@ # Upcoming Release ## Major features and improvements -* Moved `PartitionedDataSet` and `IncrementalDataSet` from the core Kedro repo to `kedro-datasets`. - ## Bug fixes and other changes ## Upcoming deprecations for Kedro-Datasets 2.0.0 * Renamed dataset and error classes, in accordance with the [Kedro lexicon](https://github.com/kedro-org/kedro/wiki/Kedro-documentation-style-guide#kedro-lexicon). 
Dataset classes ending with "DataSet" are deprecated and will be removed in 2.0.0. diff --git a/kedro-datasets/kedro_datasets/partitions/__init__.py b/kedro-datasets/kedro_datasets/partitions/__init__.py deleted file mode 100644 index 2f464a907..000000000 --- a/kedro-datasets/kedro_datasets/partitions/__init__.py +++ /dev/null @@ -1,11 +0,0 @@ -"""``AbstractDataset`` implementation to load/save data in partitions -from/to any underlying Dataset format. -""" - -__all__ = ["PartitionedDataset", "IncrementalDataset"] - -from contextlib import suppress - -with suppress(ImportError): - from .incremental_dataset import IncrementalDataset - from .partitioned_dataset import PartitionedDataset diff --git a/kedro-datasets/kedro_datasets/partitions/incremental_dataset.py b/kedro-datasets/kedro_datasets/partitions/incremental_dataset.py deleted file mode 100644 index 9623a5893..000000000 --- a/kedro-datasets/kedro_datasets/partitions/incremental_dataset.py +++ /dev/null @@ -1,237 +0,0 @@ -"""``IncrementalDataset`` inherits from ``PartitionedDataset``, which loads -and saves partitioned file-like data using the underlying dataset -definition. ``IncrementalDataset`` also stores the information about the last -processed partition in so-called `checkpoint` that is persisted to the location -of the data partitions by default, so that subsequent pipeline run loads only -new partitions past the checkpoint.It also uses `fsspec` for filesystem level operations. -""" -from __future__ import annotations - -import operator -from copy import deepcopy -from typing import Any, Callable - -from cachetools import cachedmethod -from kedro.io.core import ( - VERSION_KEY, - VERSIONED_FLAG_KEY, - AbstractDataset, - DatasetError, - parse_dataset_definition, -) -from kedro.io.data_catalog import CREDENTIALS_KEY -from kedro.utils import load_obj - -from .partitioned_dataset import KEY_PROPAGATION_WARNING, PartitionedDataset - - -class IncrementalDataset(PartitionedDataset): - """``IncrementalDataset`` inherits from ``PartitionedDataset``, which loads - and saves partitioned file-like data using the underlying dataset - definition. For filesystem level operations it uses `fsspec`: - https://github.com/intake/filesystem_spec. ``IncrementalDataset`` also stores - the information about the last processed partition in so-called `checkpoint` - that is persisted to the location of the data partitions by default, so that - subsequent pipeline run loads only new partitions past the checkpoint. 
- - Example: - :: - - >>> from kedro_datasets.partitions import IncrementalDataset - >>> - >>> # these credentials will be passed to: - >>> # a) 'fsspec.filesystem()' call, - >>> # b) the dataset initializer, - >>> # c) the checkpoint initializer - >>> credentials = {"key1": "secret1", "key2": "secret2"} - >>> - >>> data_set = IncrementalDataset( - >>> path="s3://bucket-name/path/to/folder", - >>> dataset="pandas.CSVDataset", - >>> credentials=credentials - >>> ) - >>> loaded = data_set.load() # loads all available partitions - >>> # assert isinstance(loaded, dict) - >>> - >>> data_set.confirm() # update checkpoint value to the last processed partition ID - >>> reloaded = data_set.load() # still loads all available partitions - >>> - >>> data_set.release() # clears load cache - >>> # returns an empty dictionary as no new partitions were added - >>> data_set.load() - """ - - DEFAULT_CHECKPOINT_TYPE = "kedro_datasets.text.TextDataSet" - DEFAULT_CHECKPOINT_FILENAME = "CHECKPOINT" - - def __init__( # noqa: PLR0913 - self, - path: str, - dataset: str | type[AbstractDataset] | dict[str, Any], - checkpoint: str | dict[str, Any] | None = None, - filepath_arg: str = "filepath", - filename_suffix: str = "", - credentials: dict[str, Any] = None, - load_args: dict[str, Any] = None, - fs_args: dict[str, Any] = None, - metadata: dict[str, Any] = None, - ) -> None: - """Creates a new instance of ``IncrementalDataset``. - - Args: - path: Path to the folder containing partitioned data. - If path starts with the protocol (e.g., ``s3://``) then the - corresponding ``fsspec`` concrete filesystem implementation will - be used. If protocol is not specified, - ``fsspec.implementations.local.LocalFileSystem`` will be used. - **Note:** Some concrete implementations are bundled with ``fsspec``, - while others (like ``s3`` or ``gcs``) must be installed separately - prior to usage of the ``PartitionedDataset``. - dataset: Underlying dataset definition. This is used to instantiate - the dataset for each file located inside the ``path``. - Accepted formats are: - a) object of a class that inherits from ``AbstractDataset`` - b) a string representing a fully qualified class name to such class - c) a dictionary with ``type`` key pointing to a string from b), - other keys are passed to the Dataset initializer. - Credentials for the dataset can be explicitly specified in - this configuration. - checkpoint: Optional checkpoint configuration. Accepts a dictionary - with the corresponding dataset definition including ``filepath`` - (unlike ``dataset`` argument). Checkpoint configuration is - described here: - https://kedro.readthedocs.io/en/stable/data/kedro_io.html#checkpoint-configuration - Credentials for the checkpoint can be explicitly specified - in this configuration. - filepath_arg: Underlying dataset initializer argument that will - contain a path to each corresponding partition file. - If unspecified, defaults to "filepath". - filename_suffix: If specified, only partitions that end with this - string will be processed. - credentials: Protocol-specific options that will be passed to - ``fsspec.filesystem`` - https://filesystem-spec.readthedocs.io/en/latest/api.html#fsspec.filesystem, - the dataset dataset initializer and the checkpoint. If - the dataset or the checkpoint configuration contains explicit - credentials spec, then such spec will take precedence. 
- All possible credentials management scenarios are documented here: - https://kedro.readthedocs.io/en/stable/data/kedro_io.html#partitioned-dataset-credentials - load_args: Keyword arguments to be passed into ``find()`` method of - the filesystem implementation. - fs_args: Extra arguments to pass into underlying filesystem class constructor - (e.g. `{"project": "my-project"}` for ``GCSFileSystem``). - metadata: Any arbitrary metadata. - This is ignored by Kedro, but may be consumed by users or external plugins. - - Raises: - DatasetError: If versioning is enabled for the underlying dataset. - """ - - super().__init__( - path=path, - dataset=dataset, - filepath_arg=filepath_arg, - filename_suffix=filename_suffix, - credentials=credentials, - load_args=load_args, - fs_args=fs_args, - ) - - self._checkpoint_config = self._parse_checkpoint_config(checkpoint) - self._force_checkpoint = self._checkpoint_config.pop("force_checkpoint", None) - self.metadata = metadata - - comparison_func = self._checkpoint_config.pop("comparison_func", operator.gt) - if isinstance(comparison_func, str): - comparison_func = load_obj(comparison_func) - self._comparison_func = comparison_func - - def _parse_checkpoint_config( - self, checkpoint_config: str | dict[str, Any] | None - ) -> dict[str, Any]: - checkpoint_config = deepcopy(checkpoint_config) - if isinstance(checkpoint_config, str): - checkpoint_config = {"force_checkpoint": checkpoint_config} - checkpoint_config = checkpoint_config or {} - - for key in {VERSION_KEY, VERSIONED_FLAG_KEY} & checkpoint_config.keys(): - raise DatasetError( - f"'{self.__class__.__name__}' does not support versioning of the " - f"checkpoint. Please remove '{key}' key from the checkpoint definition." - ) - - default_checkpoint_path = self._sep.join( - [self._normalized_path.rstrip(self._sep), self.DEFAULT_CHECKPOINT_FILENAME] - ) - default_config = { - "type": self.DEFAULT_CHECKPOINT_TYPE, - self._filepath_arg: default_checkpoint_path, - } - if self._credentials: - default_config[CREDENTIALS_KEY] = deepcopy(self._credentials) - - if CREDENTIALS_KEY in default_config.keys() & checkpoint_config.keys(): - self._logger.warning( - KEY_PROPAGATION_WARNING, - {"keys": CREDENTIALS_KEY, "target": "checkpoint"}, - ) - - return {**default_config, **checkpoint_config} - - @cachedmethod(cache=operator.attrgetter("_partition_cache")) - def _list_partitions(self) -> list[str]: - checkpoint = self._read_checkpoint() - checkpoint_path = self._filesystem._strip_protocol( - self._checkpoint_config[self._filepath_arg] - ) - - def _is_valid_partition(partition) -> bool: - if not partition.endswith(self._filename_suffix): - return False - if partition == checkpoint_path: - return False - if checkpoint is None: - # nothing was processed yet - return True - partition_id = self._path_to_partition(partition) - return self._comparison_func(partition_id, checkpoint) - - return sorted( - part - for part in self._filesystem.find(self._normalized_path, **self._load_args) - if _is_valid_partition(part) - ) - - @property - def _checkpoint(self) -> AbstractDataset: - type_, kwargs = parse_dataset_definition(self._checkpoint_config) - return type_(**kwargs) # type: ignore - - def _read_checkpoint(self) -> str | None: - if self._force_checkpoint is not None: - return self._force_checkpoint - try: - return self._checkpoint.load() - except DatasetError: - return None - - def _load(self) -> dict[str, Callable[[], Any]]: - partitions: dict[str, Any] = {} - - for partition in self._list_partitions(): - partition_id = 
self._path_to_partition(partition) - kwargs = deepcopy(self._dataset_config) - # join the protocol back since PySpark may rely on it - kwargs[self._filepath_arg] = self._join_protocol(partition) - partitions[partition_id] = self._dataset_type( # type: ignore - **kwargs - ).load() - - return partitions - - def confirm(self) -> None: - """Confirm the dataset by updating the checkpoint value to the latest - processed partition ID""" - partition_ids = [self._path_to_partition(p) for p in self._list_partitions()] - if partition_ids: - self._checkpoint.save(partition_ids[-1]) # checkpoint to last partition diff --git a/kedro-datasets/kedro_datasets/partitions/partitioned_dataset.py b/kedro-datasets/kedro_datasets/partitions/partitioned_dataset.py deleted file mode 100644 index 74242b113..000000000 --- a/kedro-datasets/kedro_datasets/partitions/partitioned_dataset.py +++ /dev/null @@ -1,329 +0,0 @@ -"""``PartitionedDataset`` loads and saves partitioned file-like data using the -underlying dataset definition. It also uses `fsspec` for filesystem level operations. -""" -from __future__ import annotations - -import operator -from copy import deepcopy -from typing import Any, Callable, Dict -from urllib.parse import urlparse -from warnings import warn - -import fsspec -from cachetools import Cache, cachedmethod -from kedro.io.core import ( - VERSION_KEY, - VERSIONED_FLAG_KEY, - AbstractDataset, - DatasetError, - parse_dataset_definition, -) -from kedro.io.data_catalog import CREDENTIALS_KEY - -KEY_PROPAGATION_WARNING = ( - "Top-level %(keys)s will not propagate into the %(target)s since " - "%(keys)s were explicitly defined in the %(target)s config." -) - -S3_PROTOCOLS = ("s3", "s3a", "s3n") - - -class PartitionedDataset(AbstractDataset[Dict[str, Any], Dict[str, Callable[[], Any]]]): - """``PartitionedDataset`` loads and saves partitioned file-like data using the - underlying dataset definition. For filesystem level operations it uses `fsspec`: - https://github.com/intake/filesystem_spec. - - It also supports advanced features like - `lazy saving `_. - - Example usage for the - `YAML API `_: - - .. code-block:: yaml - - station_data: - type: PartitionedDataset - path: data/03_primary/station_data - dataset: - type: pandas.CSVDataset - load_args: - sep: '\\t' - save_args: - sep: '\\t' - index: true - filename_suffix: '.dat' - - Example usage for the - `Python API `_: - :: - - >>> import pandas as pd - >>> from kedro_datasets.partitions import PartitionedDataset - >>> - >>> # Create a fake pandas dataframe with 10 rows of data - >>> df = pd.DataFrame([{"DAY_OF_MONTH": str(i), "VALUE": i} for i in range(1, 11)]) - >>> - >>> # Convert it to a dict of pd.DataFrame with DAY_OF_MONTH as the dict key - >>> dict_df = { - day_of_month: df[df["DAY_OF_MONTH"] == day_of_month] - for day_of_month in df["DAY_OF_MONTH"] - } - >>> - >>> # Save it as small paritions with DAY_OF_MONTH as the partition key - >>> data_set = PartitionedDataset( - path="df_with_partition", - dataset="pandas.CSVDataset", - filename_suffix=".csv" - ) - >>> # This will create a folder `df_with_partition` and save multiple files - >>> # with the dict key + filename_suffix as filename, i.e. 1.csv, 2.csv etc. - >>> data_set.save(dict_df) - >>> - >>> # This will create lazy load functions instead of loading data into memory immediately. 
- >>> loaded = data_set.load() - >>> - >>> # Load all the partitions - >>> for partition_id, partition_load_func in loaded.items(): - # The actual function that loads the data - partition_data = partition_load_func() - >>> - >>> # Add the processing logic for individual partition HERE - >>> print(partition_data) - - You can also load multiple partitions from a remote storage and combine them - like this: - :: - - >>> import pandas as pd - >>> from kedro_datasets.partitions import PartitionedDataset - >>> - >>> # these credentials will be passed to both 'fsspec.filesystem()' call - >>> # and the dataset initializer - >>> credentials = {"key1": "secret1", "key2": "secret2"} - >>> - >>> data_set = PartitionedDataset( - path="s3://bucket-name/path/to/folder", - dataset="pandas.CSVDataset", - credentials=credentials - ) - >>> loaded = data_set.load() - >>> # assert isinstance(loaded, dict) - >>> - >>> combine_all = pd.DataFrame() - >>> - >>> for partition_id, partition_load_func in loaded.items(): - partition_data = partition_load_func() - combine_all = pd.concat( - [combine_all, partition_data], ignore_index=True, sort=True - ) - >>> - >>> new_data = pd.DataFrame({"new": [1, 2]}) - >>> # creates "s3://bucket-name/path/to/folder/new/partition.csv" - >>> data_set.save({"new/partition.csv": new_data}) - - """ - - def __init__( # noqa: PLR0913 - self, - path: str, - dataset: str | type[AbstractDataset] | dict[str, Any], - filepath_arg: str = "filepath", - filename_suffix: str = "", - credentials: dict[str, Any] = None, - load_args: dict[str, Any] = None, - fs_args: dict[str, Any] = None, - overwrite: bool = False, - metadata: dict[str, Any] = None, - ) -> None: - """Creates a new instance of ``PartitionedDataset``. - - Args: - path: Path to the folder containing partitioned data. - If path starts with the protocol (e.g., ``s3://``) then the - corresponding ``fsspec`` concrete filesystem implementation will - be used. If protocol is not specified, - ``fsspec.implementations.local.LocalFileSystem`` will be used. - **Note:** Some concrete implementations are bundled with ``fsspec``, - while others (like ``s3`` or ``gcs``) must be installed separately - prior to usage of the ``PartitionedDataset``. - dataset: Underlying dataset definition. This is used to instantiate - the dataset for each file located inside the ``path``. - Accepted formats are: - a) object of a class that inherits from ``AbstractDataset`` - b) a string representing a fully qualified class name to such class - c) a dictionary with ``type`` key pointing to a string from b), - other keys are passed to the Dataset initializer. - Credentials for the dataset can be explicitly specified in - this configuration. - filepath_arg: Underlying dataset initializer argument that will - contain a path to each corresponding partition file. - If unspecified, defaults to "filepath". - filename_suffix: If specified, only partitions that end with this - string will be processed. - credentials: Protocol-specific options that will be passed to - ``fsspec.filesystem`` - https://filesystem-spec.readthedocs.io/en/latest/api.html#fsspec.filesystem - and the dataset initializer. If the dataset config contains - explicit credentials spec, then such spec will take precedence. - All possible credentials management scenarios are documented here: - https://kedro.readthedocs.io/en/stable/data/kedro_io.html#partitioned-dataset-credentials - load_args: Keyword arguments to be passed into ``find()`` method of - the filesystem implementation. 
- fs_args: Extra arguments to pass into underlying filesystem class constructor - (e.g. `{"project": "my-project"}` for ``GCSFileSystem``) - overwrite: If True, any existing partitions will be removed. - metadata: Any arbitrary metadata. - This is ignored by Kedro, but may be consumed by users or external plugins. - - Raises: - DatasetError: If versioning is enabled for the underlying dataset. - """ - from fsspec.utils import infer_storage_options # for performance reasons - - super().__init__() - - self._path = path - self._filename_suffix = filename_suffix - self._overwrite = overwrite - self._protocol = infer_storage_options(self._path)["protocol"] - self._partition_cache: Cache = Cache(maxsize=1) - self.metadata = metadata - - dataset = dataset if isinstance(dataset, dict) else {"type": dataset} - self._dataset_type, self._dataset_config = parse_dataset_definition(dataset) - if VERSION_KEY in self._dataset_config: - raise DatasetError( - f"'{self.__class__.__name__}' does not support versioning of the " - f"underlying dataset. Please remove '{VERSIONED_FLAG_KEY}' flag from " - f"the dataset definition." - ) - - if credentials: - if CREDENTIALS_KEY in self._dataset_config: - self._logger.warning( - KEY_PROPAGATION_WARNING, - {"keys": CREDENTIALS_KEY, "target": "underlying dataset"}, - ) - else: - self._dataset_config[CREDENTIALS_KEY] = deepcopy(credentials) - - self._credentials = deepcopy(credentials) or {} - - self._fs_args = deepcopy(fs_args) or {} - if self._fs_args: - if "fs_args" in self._dataset_config: - self._logger.warning( - KEY_PROPAGATION_WARNING, - {"keys": "filesystem arguments", "target": "underlying dataset"}, - ) - else: - self._dataset_config["fs_args"] = deepcopy(self._fs_args) - - self._filepath_arg = filepath_arg - if self._filepath_arg in self._dataset_config: - warn( - f"'{self._filepath_arg}' key must not be specified in the dataset " - f"definition as it will be overwritten by partition path" - ) - - self._load_args = deepcopy(load_args) or {} - self._sep = self._filesystem.sep - # since some filesystem implementations may implement a global cache - self._invalidate_caches() - - @property - def _filesystem(self): - protocol = "s3" if self._protocol in S3_PROTOCOLS else self._protocol - return fsspec.filesystem(protocol, **self._credentials, **self._fs_args) - - @property - def _normalized_path(self) -> str: - if self._protocol in S3_PROTOCOLS: - return urlparse(self._path)._replace(scheme="s3").geturl() - return self._path - - @cachedmethod(cache=operator.attrgetter("_partition_cache")) - def _list_partitions(self) -> list[str]: - return [ - path - for path in self._filesystem.find(self._normalized_path, **self._load_args) - if path.endswith(self._filename_suffix) - ] - - def _join_protocol(self, path: str) -> str: - protocol_prefix = f"{self._protocol}://" - if self._path.startswith(protocol_prefix) and not path.startswith( - protocol_prefix - ): - return f"{protocol_prefix}{path}" - return path - - def _partition_to_path(self, path: str): - dir_path = self._path.rstrip(self._sep) - path = path.lstrip(self._sep) - full_path = self._sep.join([dir_path, path]) + self._filename_suffix - return full_path - - def _path_to_partition(self, path: str) -> str: - dir_path = self._filesystem._strip_protocol(self._normalized_path) - path = path.split(dir_path, 1).pop().lstrip(self._sep) - if self._filename_suffix and path.endswith(self._filename_suffix): - path = path[: -len(self._filename_suffix)] - return path - - def _load(self) -> dict[str, Callable[[], Any]]: - 
partitions = {} - - for partition in self._list_partitions(): - kwargs = deepcopy(self._dataset_config) - # join the protocol back since PySpark may rely on it - kwargs[self._filepath_arg] = self._join_protocol(partition) - dataset = self._dataset_type(**kwargs) # type: ignore - partition_id = self._path_to_partition(partition) - partitions[partition_id] = dataset.load - - if not partitions: - raise DatasetError(f"No partitions found in '{self._path}'") - - return partitions - - def _save(self, data: dict[str, Any]) -> None: - if self._overwrite and self._filesystem.exists(self._normalized_path): - self._filesystem.rm(self._normalized_path, recursive=True) - - for partition_id, partition_data in sorted(data.items()): - kwargs = deepcopy(self._dataset_config) - partition = self._partition_to_path(partition_id) - # join the protocol back since tools like PySpark may rely on it - kwargs[self._filepath_arg] = self._join_protocol(partition) - dataset = self._dataset_type(**kwargs) # type: ignore - if callable(partition_data): - partition_data = partition_data() # noqa: PLW2901 - dataset.save(partition_data) - self._invalidate_caches() - - def _describe(self) -> dict[str, Any]: - clean_dataset_config = ( - {k: v for k, v in self._dataset_config.items() if k != CREDENTIALS_KEY} - if isinstance(self._dataset_config, dict) - else self._dataset_config - ) - return { - "path": self._path, - "dataset_type": self._dataset_type.__name__, - "dataset_config": clean_dataset_config, - } - - def _invalidate_caches(self) -> None: - self._partition_cache.clear() - self._filesystem.invalidate_cache(self._normalized_path) - - def _exists(self) -> bool: - return bool(self._list_partitions()) - - def _release(self) -> None: - super()._release() - self._invalidate_caches() diff --git a/kedro-datasets/tests/partitions/__init__.py b/kedro-datasets/tests/partitions/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/kedro-datasets/tests/partitions/test_incremental_dataset.py b/kedro-datasets/tests/partitions/test_incremental_dataset.py deleted file mode 100644 index 7c880f88d..000000000 --- a/kedro-datasets/tests/partitions/test_incremental_dataset.py +++ /dev/null @@ -1,508 +0,0 @@ -from __future__ import annotations - -import os -import re -from pathlib import Path -from typing import Any - -import boto3 -import pandas as pd -import pytest -from kedro.io.core import AbstractDataset, DatasetError -from kedro.io.data_catalog import CREDENTIALS_KEY -from moto import mock_s3 -from pandas.util.testing import assert_frame_equal - -from kedro_datasets.partitions import IncrementalDataset -from kedro_datasets.pickle import PickleDataSet -from kedro_datasets.text import TextDataSet - -DATASET = "kedro_datasets.pandas.csv_dataset.CSVDataSet" - - -@pytest.fixture -def partitioned_data_pandas(): - return { - f"p{counter:02d}/data.csv": pd.DataFrame( - {"part": counter, "col": list(range(counter + 1))} - ) - for counter in range(5) - } - - -@pytest.fixture -def local_csvs(tmp_path, partitioned_data_pandas): - local_dir = Path(tmp_path / "csvs") - local_dir.mkdir() - - for k, data in partitioned_data_pandas.items(): - path = local_dir / k - path.parent.mkdir(parents=True) - data.to_csv(str(path), index=False) - return local_dir - - -class DummyDataset(AbstractDataset): # pragma: no cover - def __init__(self, filepath): - pass - - def _describe(self) -> dict[str, Any]: - return {"dummy": True} - - def _load(self) -> Any: - pass - - def _save(self, data: Any) -> None: - pass - - -def dummy_gt_func(value1: 
str, value2: str): - return value1 > value2 - - -def dummy_lt_func(value1: str, value2: str): - return value1 < value2 - - -class TestIncrementalDatasetLocal: - def test_load_and_confirm(self, local_csvs, partitioned_data_pandas): - """Test the standard flow for loading, confirming and reloading - an IncrementalDataset""" - pds = IncrementalDataset(str(local_csvs), DATASET) - loaded = pds.load() - assert loaded.keys() == partitioned_data_pandas.keys() - for partition_id, data in loaded.items(): - assert_frame_equal(data, partitioned_data_pandas[partition_id]) - - checkpoint_path = local_csvs / pds.DEFAULT_CHECKPOINT_FILENAME - assert not checkpoint_path.exists() - pds.confirm() - assert checkpoint_path.is_file() - assert checkpoint_path.read_text() == pds._read_checkpoint() == "p04/data.csv" - - reloaded = pds.load() - assert reloaded.keys() == loaded.keys() - - pds.release() - reloaded_after_release = pds.load() - assert not reloaded_after_release - - def test_save(self, local_csvs): - """Test saving a new partition into an IncrementalDataset""" - df = pd.DataFrame({"dummy": [1, 2, 3]}) - new_partition_key = "p05/data.csv" - new_partition_path = local_csvs / new_partition_key - pds = IncrementalDataset(str(local_csvs), DATASET) - - assert not new_partition_path.exists() - assert new_partition_key not in pds.load() - - pds.save({new_partition_key: df}) - assert new_partition_path.exists() - loaded = pds.load() - assert_frame_equal(loaded[new_partition_key], df) - - @pytest.mark.parametrize( - "filename_suffix,expected_partitions", - [ - ( - "", - { - "p00/data.csv", - "p01/data.csv", - "p02/data.csv", - "p03/data.csv", - "p04/data.csv", - }, - ), - (".csv", {"p00/data", "p01/data", "p02/data", "p03/data", "p04/data"}), - (".fake", set()), - ], - ) - def test_filename_suffix(self, filename_suffix, expected_partitions, local_csvs): - """Test how specifying filename_suffix affects the available - partitions and their names""" - pds = IncrementalDataset( - str(local_csvs), DATASET, filename_suffix=filename_suffix - ) - loaded = pds.load() - assert loaded.keys() == expected_partitions - - @pytest.mark.parametrize( - "forced_checkpoint,expected_partitions", - [ - ( - "", - { - "p00/data.csv", - "p01/data.csv", - "p02/data.csv", - "p03/data.csv", - "p04/data.csv", - }, - ), - ( - "p00/data.csv", - {"p01/data.csv", "p02/data.csv", "p03/data.csv", "p04/data.csv"}, - ), - ("p03/data.csv", {"p04/data.csv"}), - ], - ) - def test_force_checkpoint_no_checkpoint_file( - self, forced_checkpoint, expected_partitions, local_csvs - ): - """Test how forcing checkpoint value affects the available partitions - if the checkpoint file does not exist""" - pds = IncrementalDataset(str(local_csvs), DATASET, checkpoint=forced_checkpoint) - loaded = pds.load() - assert loaded.keys() == expected_partitions - - confirm_path = local_csvs / pds.DEFAULT_CHECKPOINT_FILENAME - assert not confirm_path.exists() - pds.confirm() - assert confirm_path.is_file() - assert confirm_path.read_text() == max(expected_partitions) - - @pytest.mark.parametrize( - "forced_checkpoint,expected_partitions", - [ - ( - "", - { - "p00/data.csv", - "p01/data.csv", - "p02/data.csv", - "p03/data.csv", - "p04/data.csv", - }, - ), - ( - "p00/data.csv", - {"p01/data.csv", "p02/data.csv", "p03/data.csv", "p04/data.csv"}, - ), - ("p03/data.csv", {"p04/data.csv"}), - ], - ) - def test_force_checkpoint_checkpoint_file_exists( - self, forced_checkpoint, expected_partitions, local_csvs - ): - """Test how forcing checkpoint value affects the available 
partitions - if the checkpoint file exists""" - IncrementalDataset(str(local_csvs), DATASET).confirm() - checkpoint = local_csvs / IncrementalDataset.DEFAULT_CHECKPOINT_FILENAME - assert checkpoint.read_text() == "p04/data.csv" - - pds = IncrementalDataset(str(local_csvs), DATASET, checkpoint=forced_checkpoint) - assert pds._checkpoint.exists() - loaded = pds.load() - assert loaded.keys() == expected_partitions - - @pytest.mark.parametrize( - "forced_checkpoint", ["p04/data.csv", "p10/data.csv", "p100/data.csv"] - ) - def test_force_checkpoint_no_partitions(self, forced_checkpoint, local_csvs): - """Test that forcing the checkpoint to certain values results in no - partitions being returned""" - pds = IncrementalDataset(str(local_csvs), DATASET, checkpoint=forced_checkpoint) - loaded = pds.load() - assert not loaded - - confirm_path = local_csvs / pds.DEFAULT_CHECKPOINT_FILENAME - assert not confirm_path.exists() - pds.confirm() - # confirming with no partitions available must have no effect - assert not confirm_path.exists() - - def test_checkpoint_path(self, local_csvs, partitioned_data_pandas): - """Test configuring a different checkpoint path""" - checkpoint_path = local_csvs / "checkpoint_folder" / "checkpoint_file" - assert not checkpoint_path.exists() - - IncrementalDataset( - str(local_csvs), DATASET, checkpoint={"filepath": str(checkpoint_path)} - ).confirm() - assert checkpoint_path.is_file() - assert checkpoint_path.read_text() == max(partitioned_data_pandas) - - @pytest.mark.parametrize( - "checkpoint_config,expected_checkpoint_class", - [ - (None, TextDataSet), - ({"type": "kedro_datasets.pickle.PickleDataSet"}, PickleDataSet), - ( - {"type": "tests.partitions.test_incremental_dataset.DummyDataset"}, - DummyDataset, - ), - ], - ) - def test_checkpoint_type( - self, tmp_path, checkpoint_config, expected_checkpoint_class - ): - """Test configuring a different checkpoint dataset type""" - pds = IncrementalDataset(str(tmp_path), DATASET, checkpoint=checkpoint_config) - assert isinstance(pds._checkpoint, expected_checkpoint_class) - - @pytest.mark.parametrize( - "checkpoint_config,error_pattern", - [ - ( - {"versioned": True}, - "'IncrementalDataset' does not support versioning " - "of the checkpoint. Please remove 'versioned' key from the " - "checkpoint definition.", - ), - ( - {"version": None}, - "'IncrementalDataset' does not support versioning " - "of the checkpoint. 
Please remove 'version' key from the " - "checkpoint definition.", - ), - ], - ) - def test_version_not_allowed(self, tmp_path, checkpoint_config, error_pattern): - """Test that invalid checkpoint configurations raise expected errors""" - with pytest.raises(DatasetError, match=re.escape(error_pattern)): - IncrementalDataset(str(tmp_path), DATASET, checkpoint=checkpoint_config) - - @pytest.mark.parametrize( - "pds_config,fs_creds,dataset_creds,checkpoint_creds", - [ - ( - {"dataset": DATASET, "credentials": {"cred": "common"}}, - {"cred": "common"}, - {"cred": "common"}, - {"cred": "common"}, - ), - ( - { - "dataset": {"type": DATASET, "credentials": {"ds": "only"}}, - "credentials": {"cred": "common"}, - }, - {"cred": "common"}, - {"ds": "only"}, - {"cred": "common"}, - ), - ( - { - "dataset": DATASET, - "credentials": {"cred": "common"}, - "checkpoint": {"credentials": {"cp": "only"}}, - }, - {"cred": "common"}, - {"cred": "common"}, - {"cp": "only"}, - ), - ( - { - "dataset": {"type": DATASET, "credentials": {"ds": "only"}}, - "checkpoint": {"credentials": {"cp": "only"}}, - }, - {}, - {"ds": "only"}, - {"cp": "only"}, - ), - ( - { - "dataset": {"type": DATASET, "credentials": None}, - "credentials": {"cred": "common"}, - "checkpoint": {"credentials": None}, - }, - {"cred": "common"}, - None, - None, - ), - ], - ) - def test_credentials(self, pds_config, fs_creds, dataset_creds, checkpoint_creds): - """Test correctness of credentials propagation into the dataset and - checkpoint constructors""" - pds = IncrementalDataset(str(Path.cwd()), **pds_config) - assert pds._credentials == fs_creds - assert pds._dataset_config[CREDENTIALS_KEY] == dataset_creds - assert pds._checkpoint_config[CREDENTIALS_KEY] == checkpoint_creds - - @pytest.mark.parametrize( - "comparison_func,expected_partitions", - [ - ( - "tests.partitions.test_incremental_dataset.dummy_gt_func", - {"p03/data.csv", "p04/data.csv"}, - ), - (dummy_gt_func, {"p03/data.csv", "p04/data.csv"}), - ( - "tests.partitions.test_incremental_dataset.dummy_lt_func", - {"p00/data.csv", "p01/data.csv"}, - ), - (dummy_lt_func, {"p00/data.csv", "p01/data.csv"}), - ], - ) - def test_comparison_func(self, comparison_func, expected_partitions, local_csvs): - """Test that specifying a custom function for comparing the checkpoint value - to a partition id results in expected partitions being returned on load""" - checkpoint_config = { - "force_checkpoint": "p02/data.csv", - "comparison_func": comparison_func, - } - pds = IncrementalDataset(str(local_csvs), DATASET, checkpoint=checkpoint_config) - assert pds.load().keys() == expected_partitions - - -BUCKET_NAME = "fake_bucket_name" - - -@pytest.fixture -def mocked_s3_bucket(): - """Create a bucket for testing using moto.""" - with mock_s3(): - conn = boto3.client( - "s3", - aws_access_key_id="fake_access_key", - aws_secret_access_key="fake_secret_key", - ) - conn.create_bucket(Bucket=BUCKET_NAME) - yield conn - - -@pytest.fixture -def mocked_csvs_in_s3(mocked_s3_bucket, partitioned_data_pandas): - prefix = "csvs" - for key, data in partitioned_data_pandas.items(): - mocked_s3_bucket.put_object( - Bucket=BUCKET_NAME, - Key=f"{prefix}/{key}", - Body=data.to_csv(index=False), - ) - return f"s3://{BUCKET_NAME}/{prefix}" - - -class TestPartitionedDatasetS3: - os.environ["AWS_ACCESS_KEY_ID"] = "FAKE_ACCESS_KEY" - os.environ["AWS_SECRET_ACCESS_KEY"] = "FAKE_SECRET_KEY" - - def test_load_and_confirm(self, mocked_csvs_in_s3, partitioned_data_pandas): - """Test the standard flow for loading, confirming and 
reloading - a IncrementalDataset in S3""" - pds = IncrementalDataset(mocked_csvs_in_s3, DATASET) - assert pds._checkpoint._protocol == "s3" - loaded = pds.load() - assert loaded.keys() == partitioned_data_pandas.keys() - for partition_id, data in loaded.items(): - assert_frame_equal(data, partitioned_data_pandas[partition_id]) - - assert not pds._checkpoint.exists() - assert pds._read_checkpoint() is None - pds.confirm() - assert pds._checkpoint.exists() - assert pds._read_checkpoint() == max(partitioned_data_pandas) - - def test_load_and_confirm_s3a( - self, mocked_csvs_in_s3, partitioned_data_pandas, mocker - ): - s3a_path = f"s3a://{mocked_csvs_in_s3.split('://', 1)[1]}" - pds = IncrementalDataset(s3a_path, DATASET) - assert pds._protocol == "s3a" - assert pds._checkpoint._protocol == "s3" - - mocked_ds = mocker.patch.object(pds, "_dataset_type") - mocked_ds.__name__ = "mocked" - loaded = pds.load() - - assert loaded.keys() == partitioned_data_pandas.keys() - assert not pds._checkpoint.exists() - assert pds._read_checkpoint() is None - pds.confirm() - assert pds._checkpoint.exists() - assert pds._read_checkpoint() == max(partitioned_data_pandas) - - @pytest.mark.parametrize( - "forced_checkpoint,expected_partitions", - [ - ( - "", - { - "p00/data.csv", - "p01/data.csv", - "p02/data.csv", - "p03/data.csv", - "p04/data.csv", - }, - ), - ( - "p00/data.csv", - {"p01/data.csv", "p02/data.csv", "p03/data.csv", "p04/data.csv"}, - ), - ("p03/data.csv", {"p04/data.csv"}), - ], - ) - def test_force_checkpoint_no_checkpoint_file( - self, forced_checkpoint, expected_partitions, mocked_csvs_in_s3 - ): - """Test how forcing checkpoint value affects the available partitions - in S3 if the checkpoint file does not exist""" - pds = IncrementalDataset( - mocked_csvs_in_s3, DATASET, checkpoint=forced_checkpoint - ) - loaded = pds.load() - assert loaded.keys() == expected_partitions - - assert not pds._checkpoint.exists() - pds.confirm() - assert pds._checkpoint.exists() - assert pds._checkpoint.load() == max(expected_partitions) - - @pytest.mark.parametrize( - "forced_checkpoint,expected_partitions", - [ - ( - "", - { - "p00/data.csv", - "p01/data.csv", - "p02/data.csv", - "p03/data.csv", - "p04/data.csv", - }, - ), - ( - "p00/data.csv", - {"p01/data.csv", "p02/data.csv", "p03/data.csv", "p04/data.csv"}, - ), - ("p03/data.csv", {"p04/data.csv"}), - ], - ) - def test_force_checkpoint_checkpoint_file_exists( - self, forced_checkpoint, expected_partitions, mocked_csvs_in_s3 - ): - """Test how forcing checkpoint value affects the available partitions - in S3 if the checkpoint file exists""" - # create checkpoint and assert that it exists - IncrementalDataset(mocked_csvs_in_s3, DATASET).confirm() - checkpoint_path = ( - f"{mocked_csvs_in_s3}/{IncrementalDataset.DEFAULT_CHECKPOINT_FILENAME}" - ) - checkpoint_value = TextDataSet(checkpoint_path).load() - assert checkpoint_value == "p04/data.csv" - - pds = IncrementalDataset( - mocked_csvs_in_s3, DATASET, checkpoint=forced_checkpoint - ) - assert pds._checkpoint.exists() - loaded = pds.load() - assert loaded.keys() == expected_partitions - - @pytest.mark.parametrize( - "forced_checkpoint", ["p04/data.csv", "p10/data.csv", "p100/data.csv"] - ) - def test_force_checkpoint_no_partitions(self, forced_checkpoint, mocked_csvs_in_s3): - """Test that forcing the checkpoint to certain values results in no - partitions returned from S3""" - pds = IncrementalDataset( - mocked_csvs_in_s3, DATASET, checkpoint=forced_checkpoint - ) - loaded = pds.load() - assert not loaded - - 
assert not pds._checkpoint.exists() - pds.confirm() - # confirming with no partitions available must have no effect - assert not pds._checkpoint.exists() diff --git a/kedro-datasets/tests/partitions/test_partitioned_dataset.py b/kedro-datasets/tests/partitions/test_partitioned_dataset.py deleted file mode 100644 index 4feb79ac4..000000000 --- a/kedro-datasets/tests/partitions/test_partitioned_dataset.py +++ /dev/null @@ -1,540 +0,0 @@ -import logging -import os -import re -from pathlib import Path - -import boto3 -import pandas as pd -import pytest -import s3fs -from kedro.io import DatasetError -from kedro.io.data_catalog import CREDENTIALS_KEY -from moto import mock_s3 -from pandas.util.testing import assert_frame_equal - -from kedro_datasets.pandas import CSVDataset, ParquetDataset -from kedro_datasets.partitions import PartitionedDataset -from kedro_datasets.partitions.partitioned_dataset import KEY_PROPAGATION_WARNING - - -@pytest.fixture -def partitioned_data_pandas(): - keys = ("p1/data1.csv", "p2.csv", "p1/data2.csv", "p3", "_p4") - return { - k: pd.DataFrame({"part": k, "counter": list(range(counter))}) - for counter, k in enumerate(keys, 1) - } - - -@pytest.fixture -def local_csvs(tmp_path, partitioned_data_pandas): - local_dir = Path(str(tmp_path / "csvs")) - local_dir.mkdir() - - for k, data in partitioned_data_pandas.items(): - path = local_dir / k - path.parent.mkdir(parents=True, exist_ok=True) - data.to_csv(str(path), index=False) - return local_dir - - -LOCAL_DATASET_DEFINITION = [ - "pandas.CSVDataset", - "kedro_datasets.pandas.CSVDataset", - CSVDataset, - {"type": "kedro_datasets.pandas.CSVDataset", "save_args": {"index": False}}, - {"type": CSVDataset}, -] - - -class FakeDataset: # pylint: disable=too-few-public-methods - pass - - -class TestPartitionedDatasetLocal: - @pytest.mark.parametrize("dataset", LOCAL_DATASET_DEFINITION) - @pytest.mark.parametrize( - "suffix,expected_num_parts", [("", 5), (".csv", 3), ("p4", 1)] - ) - def test_load( - self, dataset, local_csvs, partitioned_data_pandas, suffix, expected_num_parts - ): - pds = PartitionedDataset(str(local_csvs), dataset, filename_suffix=suffix) - loaded_partitions = pds.load() - - assert len(loaded_partitions.keys()) == expected_num_parts - for partition_id, load_func in loaded_partitions.items(): - df = load_func() - assert_frame_equal(df, partitioned_data_pandas[partition_id + suffix]) - if suffix: - assert not partition_id.endswith(suffix) - - @pytest.mark.parametrize("dataset", LOCAL_DATASET_DEFINITION) - @pytest.mark.parametrize("suffix", ["", ".csv"]) - def test_save(self, dataset, local_csvs, suffix): - pds = PartitionedDataset(str(local_csvs), dataset, filename_suffix=suffix) - original_data = pd.DataFrame({"foo": 42, "bar": ["a", "b", None]}) - part_id = "new/data" - pds.save({part_id: original_data}) - - assert (local_csvs / "new" / ("data" + suffix)).is_file() - loaded_partitions = pds.load() - assert part_id in loaded_partitions - reloaded_data = loaded_partitions[part_id]() - assert_frame_equal(reloaded_data, original_data) - - @pytest.mark.parametrize("dataset", LOCAL_DATASET_DEFINITION) - @pytest.mark.parametrize("suffix", ["", ".csv"]) - def test_lazy_save(self, dataset, local_csvs, suffix): - pds = PartitionedDataset(str(local_csvs), dataset, filename_suffix=suffix) - - def original_data(): - return pd.DataFrame({"foo": 42, "bar": ["a", "b", None]}) - - part_id = "new/data" - pds.save({part_id: original_data}) - - assert (local_csvs / "new" / ("data" + suffix)).is_file() - loaded_partitions = 
pds.load() - assert part_id in loaded_partitions - reloaded_data = loaded_partitions[part_id]() - assert_frame_equal(reloaded_data, original_data()) - - def test_save_invalidates_cache(self, local_csvs, mocker): - """Test that save calls invalidate partition cache""" - pds = PartitionedDataset(str(local_csvs), "pandas.CSVDataset") - mocked_fs_invalidate = mocker.patch.object(pds._filesystem, "invalidate_cache") - first_load = pds.load() - assert pds._partition_cache.currsize == 1 - mocked_fs_invalidate.assert_not_called() - - # save clears cache - data = pd.DataFrame({"foo": 42, "bar": ["a", "b", None]}) - new_partition = "new/data.csv" - pds.save({new_partition: data}) - assert pds._partition_cache.currsize == 0 - # it seems that `_filesystem.invalidate_cache` calls itself inside, - # resulting in not one, but 2 mock calls - # hence using `assert_any_call` instead of `assert_called_once_with` - mocked_fs_invalidate.assert_any_call(pds._normalized_path) - - # new load returns new partition too - second_load = pds.load() - assert new_partition not in first_load - assert new_partition in second_load - - @pytest.mark.parametrize("overwrite,expected_num_parts", [(False, 6), (True, 1)]) - def test_overwrite(self, local_csvs, overwrite, expected_num_parts): - pds = PartitionedDataset( - str(local_csvs), "pandas.CSVDataset", overwrite=overwrite - ) - original_data = pd.DataFrame({"foo": 42, "bar": ["a", "b", None]}) - part_id = "new/data" - pds.save({part_id: original_data}) - loaded_partitions = pds.load() - - assert part_id in loaded_partitions - assert len(loaded_partitions.keys()) == expected_num_parts - - def test_release_instance_cache(self, local_csvs): - """Test that cache invalidation does not affect other instances""" - ds_a = PartitionedDataset(str(local_csvs), "pandas.CSVDataset") - ds_a.load() - ds_b = PartitionedDataset(str(local_csvs), "pandas.CSVDataset") - ds_b.load() - - assert ds_a._partition_cache.currsize == 1 - assert ds_b._partition_cache.currsize == 1 - - # invalidate cache of the dataset A - ds_a.release() - assert ds_a._partition_cache.currsize == 0 - # cache of the dataset B is unaffected - assert ds_b._partition_cache.currsize == 1 - - @pytest.mark.parametrize("dataset", ["pandas.CSVDataset", "pandas.ParquetDataset"]) - def test_exists(self, local_csvs, dataset): - assert PartitionedDataset(str(local_csvs), dataset).exists() - - empty_folder = local_csvs / "empty" / "folder" - assert not PartitionedDataset(str(empty_folder), dataset).exists() - empty_folder.mkdir(parents=True) - assert not PartitionedDataset(str(empty_folder), dataset).exists() - - @pytest.mark.parametrize("dataset", LOCAL_DATASET_DEFINITION) - def test_release(self, dataset, local_csvs): - partition_to_remove = "p2.csv" - pds = PartitionedDataset(str(local_csvs), dataset) - initial_load = pds.load() - assert partition_to_remove in initial_load - - (local_csvs / partition_to_remove).unlink() - cached_load = pds.load() - assert initial_load.keys() == cached_load.keys() - - pds.release() - load_after_release = pds.load() - assert initial_load.keys() ^ load_after_release.keys() == {partition_to_remove} - - @pytest.mark.parametrize("dataset", LOCAL_DATASET_DEFINITION) - def test_describe(self, dataset): - path = str(Path.cwd()) - pds = PartitionedDataset(path, dataset) - - assert f"path={path}" in str(pds) - assert "dataset_type=CSVDataset" in str(pds) - assert "dataset_config" in str(pds) - - def test_load_args(self, mocker): - fake_partition_name = "fake_partition" - mocked_filesystem = 
mocker.patch("fsspec.filesystem") - mocked_find = mocked_filesystem.return_value.find - mocked_find.return_value = [fake_partition_name] - - path = str(Path.cwd()) - load_args = {"maxdepth": 42, "withdirs": True} - pds = PartitionedDataset(path, "pandas.CSVDataset", load_args=load_args) - mocker.patch.object(pds, "_path_to_partition", return_value=fake_partition_name) - - assert pds.load().keys() == {fake_partition_name} - mocked_find.assert_called_once_with(path, **load_args) - - @pytest.mark.parametrize( - "credentials,expected_pds_creds,expected_dataset_creds", - [({"cred": "common"}, {"cred": "common"}, {"cred": "common"}), (None, {}, {})], - ) - def test_credentials( - self, mocker, credentials, expected_pds_creds, expected_dataset_creds - ): - mocked_filesystem = mocker.patch("fsspec.filesystem") - path = str(Path.cwd()) - pds = PartitionedDataset(path, "pandas.CSVDataset", credentials=credentials) - - assert mocked_filesystem.call_count == 2 - mocked_filesystem.assert_called_with("file", **expected_pds_creds) - if expected_dataset_creds: - assert pds._dataset_config[CREDENTIALS_KEY] == expected_dataset_creds - else: - assert CREDENTIALS_KEY not in pds._dataset_config - - str_repr = str(pds) - - def _assert_not_in_repr(value): - if isinstance(value, dict): - for k_, v_ in value.items(): - _assert_not_in_repr(k_) - _assert_not_in_repr(v_) - if value is not None: - assert str(value) not in str_repr - - _assert_not_in_repr(credentials) - - def test_fs_args(self, mocker): - fs_args = {"foo": "bar"} - - mocked_filesystem = mocker.patch("fsspec.filesystem") - path = str(Path.cwd()) - pds = PartitionedDataset(path, "pandas.CSVDataset", fs_args=fs_args) - - assert mocked_filesystem.call_count == 2 - mocked_filesystem.assert_called_with("file", **fs_args) - assert pds._dataset_config["fs_args"] == fs_args - - @pytest.mark.parametrize("dataset", ["pandas.ParquetDataset", ParquetDataset]) - def test_invalid_dataset(self, dataset, local_csvs): - pds = PartitionedDataset(str(local_csvs), dataset) - loaded_partitions = pds.load() - - for partition, df_loader in loaded_partitions.items(): - pattern = r"Failed while loading data from data set ParquetDataset(.*)" - with pytest.raises(DatasetError, match=pattern) as exc_info: - df_loader() - error_message = str(exc_info.value) - assert ( - "Either the file is corrupted or this is not a parquet file" - in error_message - ) - assert str(partition) in error_message - - @pytest.mark.parametrize( - "dataset_config,error_pattern", - [ - ("UndefinedDatasetType", "Class 'UndefinedDatasetType' not found"), - ( - "missing.module.UndefinedDatasetType", - r"Class 'missing\.module\.UndefinedDatasetType' not found", - ), - ( - FakeDataset, - r"Dataset type 'tests\.partitions\.test_partitioned_dataset\.FakeDataset' " - r"is invalid\: all data set types must extend 'AbstractDataset'", - ), - ({}, "'type' is missing from dataset catalog configuration"), - ], - ) - def test_invalid_dataset_config(self, dataset_config, error_pattern): - with pytest.raises(DatasetError, match=error_pattern): - PartitionedDataset(str(Path.cwd()), dataset_config) - - @pytest.mark.parametrize( - "dataset_config", - [ - {"type": CSVDataset, "versioned": True}, - {"type": "pandas.CSVDataset", "versioned": True}, - ], - ) - def test_versioned_dataset_not_allowed(self, dataset_config): - pattern = ( - "'PartitionedDataset' does not support versioning of the underlying " - "dataset. Please remove 'versioned' flag from the dataset definition." 
- ) - with pytest.raises(DatasetError, match=re.escape(pattern)): - PartitionedDataset(str(Path.cwd()), dataset_config) - - def test_no_partitions(self, tmpdir): - pds = PartitionedDataset(str(tmpdir), "pandas.CSVDataset") - - pattern = re.escape(f"No partitions found in '{tmpdir}'") - with pytest.raises(DatasetError, match=pattern): - pds.load() - - @pytest.mark.parametrize( - "pds_config,filepath_arg", - [ - ( - { - "path": str(Path.cwd()), - "dataset": {"type": CSVDataset, "filepath": "fake_path"}, - }, - "filepath", - ), - ( - { - "path": str(Path.cwd()), - "dataset": {"type": CSVDataset, "other_arg": "fake_path"}, - "filepath_arg": "other_arg", - }, - "other_arg", - ), - ], - ) - def test_filepath_arg_warning(self, pds_config, filepath_arg): - pattern = ( - f"'{filepath_arg}' key must not be specified in the dataset definition as it " - f"will be overwritten by partition path" - ) - with pytest.warns(UserWarning, match=re.escape(pattern)): - PartitionedDataset(**pds_config) - - def test_credentials_log_warning(self, caplog): - """Check that the warning is logged if the dataset credentials will overwrite - the top-level ones""" - pds = PartitionedDataset( - path=str(Path.cwd()), - dataset={"type": CSVDataset, "credentials": {"secret": "dataset"}}, - credentials={"secret": "global"}, - ) - log_message = KEY_PROPAGATION_WARNING % { - "keys": "credentials", - "target": "underlying dataset", - } - assert caplog.record_tuples == [("kedro.io.core", logging.WARNING, log_message)] - assert pds._dataset_config["credentials"] == {"secret": "dataset"} - - def test_fs_args_log_warning(self, caplog): - """Check that the warning is logged if the dataset filesystem - arguments will overwrite the top-level ones""" - pds = PartitionedDataset( - path=str(Path.cwd()), - dataset={"type": CSVDataset, "fs_args": {"args": "dataset"}}, - fs_args={"args": "dataset"}, - ) - log_message = KEY_PROPAGATION_WARNING % { - "keys": "filesystem arguments", - "target": "underlying dataset", - } - assert caplog.record_tuples == [("kedro.io.core", logging.WARNING, log_message)] - assert pds._dataset_config["fs_args"] == {"args": "dataset"} - - @pytest.mark.parametrize( - "pds_config,expected_ds_creds,global_creds", - [ - ( - {"dataset": "pandas.CSVDataset", "credentials": {"secret": "global"}}, - {"secret": "global"}, - {"secret": "global"}, - ), - ( - { - "dataset": { - "type": CSVDataset, - "credentials": {"secret": "expected"}, - }, - }, - {"secret": "expected"}, - {}, - ), - ( - { - "dataset": {"type": CSVDataset, "credentials": None}, - "credentials": {"secret": "global"}, - }, - None, - {"secret": "global"}, - ), - ( - { - "dataset": { - "type": CSVDataset, - "credentials": {"secret": "expected"}, - }, - "credentials": {"secret": "global"}, - }, - {"secret": "expected"}, - {"secret": "global"}, - ), - ], - ) - def test_dataset_creds(self, pds_config, expected_ds_creds, global_creds): - """Check that global credentials do not interfere dataset credentials.""" - pds = PartitionedDataset(path=str(Path.cwd()), **pds_config) - assert pds._dataset_config["credentials"] == expected_ds_creds - assert pds._credentials == global_creds - - -BUCKET_NAME = "fake_bucket_name" -S3_DATASET_DEFINITION = [ - "pandas.CSVDataset", - "kedro_datasets.pandas.CSVDataset", - CSVDataset, - {"type": "kedro_datasets.pandas.CSVDataset", "save_args": {"index": False}}, - {"type": CSVDataset}, -] - - -@pytest.fixture -def mocked_s3_bucket(): - """Create a bucket for testing using moto.""" - with mock_s3(): - conn = boto3.client( - "s3", - 
-            aws_access_key_id="fake_access_key",
-            aws_secret_access_key="fake_secret_key",
-        )
-        conn.create_bucket(Bucket=BUCKET_NAME)
-        yield conn
-
-
-@pytest.fixture
-def mocked_csvs_in_s3(mocked_s3_bucket, partitioned_data_pandas):
-    prefix = "csvs"
-    for key, data in partitioned_data_pandas.items():
-        mocked_s3_bucket.put_object(
-            Bucket=BUCKET_NAME,
-            Key=f"{prefix}/{key}",
-            Body=data.to_csv(index=False),
-        )
-    return f"s3://{BUCKET_NAME}/{prefix}"
-
-
-class TestPartitionedDatasetS3:
-    os.environ["AWS_ACCESS_KEY_ID"] = "FAKE_ACCESS_KEY"
-    os.environ["AWS_SECRET_ACCESS_KEY"] = "FAKE_SECRET_KEY"
-
-    @pytest.mark.parametrize("dataset", S3_DATASET_DEFINITION)
-    def test_load(self, dataset, mocked_csvs_in_s3, partitioned_data_pandas):
-        pds = PartitionedDataset(mocked_csvs_in_s3, dataset)
-        loaded_partitions = pds.load()
-
-        assert loaded_partitions.keys() == partitioned_data_pandas.keys()
-        for partition_id, load_func in loaded_partitions.items():
-            df = load_func()
-            assert_frame_equal(df, partitioned_data_pandas[partition_id])
-
-    def test_load_s3a(self, mocked_csvs_in_s3, partitioned_data_pandas, mocker):
-        path = mocked_csvs_in_s3.split("://", 1)[1]
-        s3a_path = f"s3a://{path}"
-        # any type is fine as long as it passes isinstance check
-        # since _dataset_type is mocked later anyways
-        pds = PartitionedDataset(s3a_path, "pandas.CSVDataset")
-        assert pds._protocol == "s3a"
-
-        mocked_ds = mocker.patch.object(pds, "_dataset_type")
-        mocked_ds.__name__ = "mocked"
-        loaded_partitions = pds.load()
-
-        assert loaded_partitions.keys() == partitioned_data_pandas.keys()
-        assert mocked_ds.call_count == len(loaded_partitions)
-        expected = [
-            mocker.call(filepath=f"{s3a_path}/{partition_id}")
-            for partition_id in loaded_partitions
-        ]
-        mocked_ds.assert_has_calls(expected, any_order=True)
-
-    @pytest.mark.parametrize("dataset", S3_DATASET_DEFINITION)
-    def test_save(self, dataset, mocked_csvs_in_s3):
-        pds = PartitionedDataset(mocked_csvs_in_s3, dataset)
-        original_data = pd.DataFrame({"foo": 42, "bar": ["a", "b", None]})
-        part_id = "new/data.csv"
-        pds.save({part_id: original_data})
-
-        s3 = s3fs.S3FileSystem()
-        assert s3.exists("/".join([mocked_csvs_in_s3, part_id]))
-
-        loaded_partitions = pds.load()
-        assert part_id in loaded_partitions
-        reloaded_data = loaded_partitions[part_id]()
-        assert_frame_equal(reloaded_data, original_data)
-
-    def test_save_s3a(self, mocked_csvs_in_s3, mocker):
-        """Test that save works in case of s3a protocol"""
-        path = mocked_csvs_in_s3.split("://", 1)[1]
-        s3a_path = f"s3a://{path}"
-        # any type is fine as long as it passes isinstance check
-        # since _dataset_type is mocked later anyways
-        pds = PartitionedDataset(s3a_path, "pandas.CSVDataset", filename_suffix=".csv")
-        assert pds._protocol == "s3a"
-
-        mocked_ds = mocker.patch.object(pds, "_dataset_type")
-        mocked_ds.__name__ = "mocked"
-        new_partition = "new/data"
-        data = "data"
-
-        pds.save({new_partition: data})
-        mocked_ds.assert_called_once_with(filepath=f"{s3a_path}/{new_partition}.csv")
-        mocked_ds.return_value.save.assert_called_once_with(data)
-
-    @pytest.mark.parametrize("dataset", ["pandas.CSVDataset", "pandas.HDFDataset"])
-    def test_exists(self, dataset, mocked_csvs_in_s3):
-        assert PartitionedDataset(mocked_csvs_in_s3, dataset).exists()
-
-        empty_folder = "/".join([mocked_csvs_in_s3, "empty", "folder"])
-        assert not PartitionedDataset(empty_folder, dataset).exists()
-
-        s3fs.S3FileSystem().mkdir(empty_folder)
-        assert not PartitionedDataset(empty_folder, dataset).exists()
-
-    @pytest.mark.parametrize("dataset", S3_DATASET_DEFINITION)
-    def test_release(self, dataset, mocked_csvs_in_s3):
-        partition_to_remove = "p2.csv"
-        pds = PartitionedDataset(mocked_csvs_in_s3, dataset)
-        initial_load = pds.load()
-        assert partition_to_remove in initial_load
-
-        s3 = s3fs.S3FileSystem()
-        s3.rm("/".join([mocked_csvs_in_s3, partition_to_remove]))
-        cached_load = pds.load()
-        assert initial_load.keys() == cached_load.keys()
-
-        pds.release()
-        load_after_release = pds.load()
-        assert initial_load.keys() ^ load_after_release.keys() == {partition_to_remove}
-
-    @pytest.mark.parametrize("dataset", S3_DATASET_DEFINITION)
-    def test_describe(self, dataset):
-        path = f"s3://{BUCKET_NAME}/foo/bar"
-        pds = PartitionedDataset(path, dataset)
-
-        assert f"path={path}" in str(pds)
-        assert "dataset_type=CSVDataset" in str(pds)
-        assert "dataset_config" in str(pds)

From 5030e355e149c469c9bdcc9f7e260f131949e7ae Mon Sep 17 00:00:00 2001
From: Merel Theisen
Date: Thu, 5 Oct 2023 17:48:21 +0100
Subject: [PATCH 8/9] Clean up

Signed-off-by: Merel Theisen
---
 kedro-datasets/docs/source/kedro_datasets.rst | 2 --
 1 file changed, 2 deletions(-)

diff --git a/kedro-datasets/docs/source/kedro_datasets.rst b/kedro-datasets/docs/source/kedro_datasets.rst
index 67f87e0e3..d8db36ee0 100644
--- a/kedro-datasets/docs/source/kedro_datasets.rst
+++ b/kedro-datasets/docs/source/kedro_datasets.rst
@@ -59,8 +59,6 @@ kedro_datasets
    kedro_datasets.pandas.SQLTableDataset
    kedro_datasets.pandas.XMLDataSet
    kedro_datasets.pandas.XMLDataset
-   kedro_datasets.partitions.IncrementalDataset
-   kedro_datasets.partitions.PartitionedDataset
    kedro_datasets.pickle.PickleDataSet
    kedro_datasets.pickle.PickleDataset
    kedro_datasets.pillow.ImageDataSet

From 43b04f66fe960cfaf3a50acbdf035ccefa3c52c2 Mon Sep 17 00:00:00 2001
From: Merel Theisen
Date: Fri, 6 Oct 2023 14:01:05 +0100
Subject: [PATCH 9/9] Add comment to explain why tensorflow tests are run separately

Signed-off-by: Merel Theisen
---
 Makefile | 1 +
 1 file changed, 1 insertion(+)

diff --git a/Makefile b/Makefile
index be475b09b..66796a3de 100644
--- a/Makefile
+++ b/Makefile
@@ -19,6 +19,7 @@ lint:
 test:
 	cd $(plugin) && pytest tests --cov-config pyproject.toml --numprocesses 4 --dist loadfile

+# Run test_tensorflow_model_dataset separately, because these tests are flaky when run as part of the full test-suite
 dataset-tests:
 	cd kedro-datasets && pytest tests --cov-config pyproject.toml --numprocesses 4 --dist loadfile --ignore tests/tensorflow
 	cd kedro-datasets && pytest tests/tensorflow/test_tensorflow_model_dataset.py --no-cov