diff --git a/kedro-datasets/RELEASE.md b/kedro-datasets/RELEASE.md
old mode 100755
new mode 100644
index 9c6deef45..3b51df818
--- a/kedro-datasets/RELEASE.md
+++ b/kedro-datasets/RELEASE.md
@@ -1,5 +1,17 @@
-# Upcoming Release:
+# Upcoming Release 1.1.0:
+
+
+## Major features and improvements:
+
+* Added the following new datasets:
+
+| Type                | Description                                                                                                            | Location                |
+| ------------------- | ---------------------------------------------------------------------------------------------------------------------- | ----------------------- |
+| `polars.CSVDataSet` | A `CSVDataSet` backed by [polars](https://www.pola.rs/), a lightning-fast dataframe package built entirely using Rust. | `kedro_datasets.polars` |
+
+## Bug fixes and other changes
+
 # Release 1.0.2:
@@ -13,6 +25,7 @@
 ## Bug fixes and other changes
 * Fixed doc string formatting in `VideoDataSet` causing the documentation builds to fail.
+
 # Release 1.0.0:
 
 First official release of Kedro-Datasets.
diff --git a/kedro-datasets/kedro_datasets/polars/__init__.py b/kedro-datasets/kedro_datasets/polars/__init__.py
new file mode 100644
index 000000000..34d39c985
--- /dev/null
+++ b/kedro-datasets/kedro_datasets/polars/__init__.py
@@ -0,0 +1,8 @@
+"""``AbstractDataSet`` implementations that produce polars DataFrames."""
+
+__all__ = ["CSVDataSet"]
+
+from contextlib import suppress
+
+with suppress(ImportError):
+    from .csv_dataset import CSVDataSet
diff --git a/kedro-datasets/kedro_datasets/polars/csv_dataset.py b/kedro-datasets/kedro_datasets/polars/csv_dataset.py
new file mode 100644
index 000000000..60a0d456a
--- /dev/null
+++ b/kedro-datasets/kedro_datasets/polars/csv_dataset.py
@@ -0,0 +1,191 @@
+"""``CSVDataSet`` loads/saves data from/to a CSV file using an underlying
+filesystem (e.g.: local, S3, GCS). It uses polars to handle the CSV file.
+"""
+import logging
+from copy import deepcopy
+from io import BytesIO
+from pathlib import PurePosixPath
+from typing import Any, Dict
+
+import fsspec
+import polars as pl
+from kedro.io.core import (
+    PROTOCOL_DELIMITER,
+    AbstractVersionedDataSet,
+    DataSetError,
+    Version,
+    get_filepath_str,
+    get_protocol_and_path,
+)
+
+logger = logging.getLogger(__name__)
+
+
+class CSVDataSet(AbstractVersionedDataSet[pl.DataFrame, pl.DataFrame]):
+    """``CSVDataSet`` loads/saves data from/to a CSV file using an underlying
+    filesystem (e.g.: local, S3, GCS). It uses polars to handle the CSV file.
+
+    Example adding a catalog entry with
+    `YAML API
+    `_:
+
+    .. code-block:: yaml
+
+        >>> cars:
+        >>>   type: polars.CSVDataSet
+        >>>   filepath: data/01_raw/company/cars.csv
+        >>>   load_args:
+        >>>     sep: ","
+        >>>     parse_dates: False
+        >>>   save_args:
+        >>>     has_header: False
+        >>>     null_value: "somenullstring"
+        >>>
+        >>> motorbikes:
+        >>>   type: polars.CSVDataSet
+        >>>   filepath: s3://your_bucket/data/02_intermediate/company/motorbikes.csv
+        >>>   credentials: dev_s3
+
+    Example using Python API:
+    ::
+
+        >>> from kedro_datasets.polars import CSVDataSet
+        >>> import polars as pl
+        >>>
+        >>> data = pl.DataFrame({'col1': [1, 2], 'col2': [4, 5],
+        >>>                      'col3': [5, 6]})
+        >>>
+        >>> data_set = CSVDataSet(filepath="test.csv")
+        >>> data_set.save(data)
+        >>> reloaded = data_set.load()
+        >>> assert data.frame_equal(reloaded)
+
+    """
+
+    DEFAULT_LOAD_ARGS = {"rechunk": True}  # type: Dict[str, Any]
+    DEFAULT_SAVE_ARGS = {}  # type: Dict[str, Any]
+
+    # pylint: disable=too-many-arguments
+    def __init__(
+        self,
+        filepath: str,
+        load_args: Dict[str, Any] = None,
+        save_args: Dict[str, Any] = None,
+        version: Version = None,
+        credentials: Dict[str, Any] = None,
+        fs_args: Dict[str, Any] = None,
+    ) -> None:
+        """Creates a new instance of ``CSVDataSet`` pointing to a concrete CSV file
+        on a specific filesystem.
+
+        Args:
+            filepath: Filepath in POSIX format to a CSV file prefixed with a protocol
+                like `s3://`. If prefix is not provided, `file` protocol (local filesystem)
+                will be used. The prefix should be any protocol supported by ``fsspec``.
+                Note: `http(s)` doesn't support versioning.
+            load_args: Polars options for loading CSV files.
+                Here you can find all available arguments:
+                https://pola-rs.github.io/polars/py-polars/html/reference/api/polars.read_csv.html#polars.read_csv
+                All defaults are preserved, but we explicitly use `rechunk=True` for `seaborn`
+                compatibility.
+            save_args: Polars options for saving CSV files.
+                Here you can find all available arguments:
+                https://pola-rs.github.io/polars/py-polars/html/reference/api/polars.DataFrame.write_csv.html
+                All defaults are preserved.
+            version: If specified, should be an instance of
+                ``kedro.io.core.Version``. If its ``load`` attribute is
+                None, the latest version will be loaded. If its ``save``
+                attribute is None, save version will be autogenerated.
+            credentials: Credentials required to get access to the underlying filesystem.
+                E.g. for ``GCSFileSystem`` it should look like `{"token": None}`.
+            fs_args: Extra arguments to pass into underlying filesystem class constructor
+                (e.g. `{"project": "my-project"}` for ``GCSFileSystem``).
+        """
+        _fs_args = deepcopy(fs_args) or {}
+        _credentials = deepcopy(credentials) or {}
+
+        protocol, path = get_protocol_and_path(filepath, version)
+        if protocol == "file":
+            _fs_args.setdefault("auto_mkdir", True)
+
+        self._protocol = protocol
+        self._storage_options = {**_credentials, **_fs_args}
+        self._fs = fsspec.filesystem(self._protocol, **self._storage_options)
+
+        super().__init__(
+            filepath=PurePosixPath(path),
+            version=version,
+            exists_function=self._fs.exists,
+            glob_function=self._fs.glob,
+        )
+
+        # Handle default load and save arguments
+        self._load_args = deepcopy(self.DEFAULT_LOAD_ARGS)
+        if load_args is not None:
+            self._load_args.update(load_args)
+        self._save_args = deepcopy(self.DEFAULT_SAVE_ARGS)
+        if save_args is not None:
+            self._save_args.update(save_args)
+
+        if "storage_options" in self._save_args or "storage_options" in self._load_args:
+            logger.warning(
+                "Dropping 'storage_options' for %s, "
+                "please specify them under 'fs_args' or 'credentials'.",
+                self._filepath,
+            )
+            self._save_args.pop("storage_options", None)
+            self._load_args.pop("storage_options", None)
+
+    def _describe(self) -> Dict[str, Any]:
+        return {
+            "filepath": self._filepath,
+            "protocol": self._protocol,
+            "load_args": self._load_args,
+            "save_args": self._save_args,
+            "version": self._version,
+        }
+
+    def _load(self) -> pl.DataFrame:
+        load_path = str(self._get_load_path())
+        if self._protocol == "file":
+            # file:// protocol seems to misbehave on Windows
+            # (),
+            # so we don't join that back to the filepath;
+            # storage_options also don't work with local paths
+            return pl.read_csv(load_path, **self._load_args)
+
+        load_path = f"{self._protocol}{PROTOCOL_DELIMITER}{load_path}"
+        return pl.read_csv(
+            load_path, storage_options=self._storage_options, **self._load_args
+        )
+
+    def _save(self, data: pl.DataFrame) -> None:
+        save_path = get_filepath_str(self._get_save_path(), self._protocol)
+
+        buf = BytesIO()
+        data.write_csv(file=buf, **self._save_args)
+
+        with self._fs.open(save_path, mode="wb") as fs_file:
+            fs_file.write(buf.getvalue())
+
+        self._invalidate_cache()
+
+    def _exists(self) -> bool:
+        try:
+            load_path = get_filepath_str(self._get_load_path(), self._protocol)
+        except DataSetError:
+            return False
+
+        return self._fs.exists(load_path)
+
+    def _release(self) -> None:
+        super()._release()
+        self._invalidate_cache()
+
+    def _invalidate_cache(self) -> None:
+        """Invalidate underlying filesystem caches."""
+        filepath = get_filepath_str(self._filepath, self._protocol)
+        self._fs.invalidate_cache(filepath)
diff --git a/kedro-datasets/setup.py b/kedro-datasets/setup.py
index f75d3cad1..9effe1fca 100644
--- a/kedro-datasets/setup.py
+++ b/kedro-datasets/setup.py
@@ -13,6 +13,7 @@
 SPARK = "pyspark>=2.2, <4.0"
 HDFS = "hdfs>=2.5.8, <3.0"
 S3FS = "s3fs>=0.3.0, <0.5"
+POLARS = "polars~=0.15.16"
 
 with open("requirements.txt", "r", encoding="utf-8") as f:
     install_requires = [x.strip() for x in f if x.strip()]
@@ -62,6 +63,7 @@ def _collect_requirements(requires):
     "pandas.GenericDataSet": [PANDAS],
 }
 pillow_require = {"pillow.ImageDataSet": ["Pillow~=9.0"]}
+polars_require = {"polars.CSVDataSet": [POLARS]}
 video_require = {
     "video.VideoDataSet": ["opencv-python~=4.5.5.64"]
 }
@@ -107,6 +109,7 @@ def _collect_requirements(requires):
     "networkx": _collect_requirements(networkx_require),
     "pandas": _collect_requirements(pandas_require),
     "pillow": _collect_requirements(pillow_require),
+    "polars": _collect_requirements(polars_require),
     "video": _collect_requirements(video_require),
     "plotly": _collect_requirements(plotly_require),
     "redis": _collect_requirements(redis_require),
@@ -123,6 +126,7 @@ def _collect_requirements(requires):
     **networkx_require,
     **pandas_require,
     **pillow_require,
+    **polars_require,
     **video_require,
     **plotly_require,
     **spark_require,
diff --git a/kedro-datasets/test_requirements.txt b/kedro-datasets/test_requirements.txt
index d0472d429..8dec3619b 100644
--- a/kedro-datasets/test_requirements.txt
+++ b/kedro-datasets/test_requirements.txt
@@ -33,6 +33,7 @@ pandas-gbq>=0.12.0, <0.18.0
 pandas~=1.3  # 1.3 for read_xml/to_xml
 Pillow~=9.0
 plotly>=4.8.0, <6.0
+polars~=0.15.13
 pre-commit>=2.9.2, <3.0  # The hook `mypy` requires pre-commit version 2.9.2.
 psutil==5.8.0
 pyarrow>=1.0, <7.0
diff --git a/kedro-datasets/tests/polars/__init__.py b/kedro-datasets/tests/polars/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/kedro-datasets/tests/polars/test_csv_dataset.py b/kedro-datasets/tests/polars/test_csv_dataset.py
new file mode 100644
index 000000000..8b05a2025
--- /dev/null
+++ b/kedro-datasets/tests/polars/test_csv_dataset.py
@@ -0,0 +1,376 @@
+import os
+import sys
+from pathlib import Path, PurePosixPath
+from time import sleep
+
+import boto3
+import polars as pl
+import pytest
+from adlfs import AzureBlobFileSystem
+from fsspec.implementations.http import HTTPFileSystem
+from fsspec.implementations.local import LocalFileSystem
+from gcsfs import GCSFileSystem
+from kedro.io import DataSetError
+from kedro.io.core import PROTOCOL_DELIMITER, Version, generate_timestamp
+from moto import mock_s3
+from polars.testing import assert_frame_equal
+from s3fs.core import S3FileSystem
+
+from kedro_datasets.polars import CSVDataSet
+
+BUCKET_NAME = "test_bucket"
+FILE_NAME = "test.csv"
+
+
+@pytest.fixture
+def filepath_csv(tmp_path):
+    return (tmp_path / "test.csv").as_posix()
+
+
+@pytest.fixture
+def csv_data_set(filepath_csv, load_args, save_args, fs_args):
+    return CSVDataSet(
+        filepath=filepath_csv, load_args=load_args, save_args=save_args, fs_args=fs_args
+    )
+
+
+@pytest.fixture
+def versioned_csv_data_set(filepath_csv, load_version, save_version):
+    return CSVDataSet(
+        filepath=filepath_csv, version=Version(load_version, save_version)
+    )
+
+
+@pytest.fixture
+def dummy_dataframe():
+    return pl.DataFrame({"col1": [1, 2], "col2": [4, 5], "col3": [5, 6]})
+
+
+@pytest.fixture
+def partitioned_data_polars():
+    return {
+        f"p{counter:02d}/data.csv": pl.DataFrame(
+            {"part": counter, "col": list(range(counter + 1))}
+        )
+        for counter in range(5)
+    }
+
+
+@pytest.fixture
+def mocked_s3_bucket():
+    """Create a bucket for testing using moto."""
+    with mock_s3():
+        conn = boto3.client(
+            "s3",
+            aws_access_key_id="fake_access_key",
+            aws_secret_access_key="fake_secret_key",
+        )
+        conn.create_bucket(Bucket=BUCKET_NAME)
+        yield conn
+
+
+@pytest.fixture
+def mocked_dataframe():
+    df = pl.DataFrame({"dummy": ["dummy"]})
+    return df
+
+
+@pytest.fixture
+def mocked_csv_in_s3(mocked_s3_bucket, mocked_dataframe: pl.DataFrame):
+
+    binarycsv = mocked_dataframe.write_csv()[:-1]
+
+    mocked_s3_bucket.put_object(
+        Bucket=BUCKET_NAME,
+        Key=FILE_NAME,
+        Body=binarycsv,
+    )
+
+    return f"s3://{BUCKET_NAME}/{FILE_NAME}"
+
+
+class TestCSVDataSet:
+    def test_save_and_load(self, csv_data_set, dummy_dataframe):
+        """Test saving and reloading the data set."""
+        csv_data_set.save(dummy_dataframe)
+        reloaded = csv_data_set.load()
+        assert_frame_equal(dummy_dataframe, reloaded)
+
+    def test_exists(self, csv_data_set, dummy_dataframe):
+        """Test `exists` method invocation for both existing and
+        nonexistent data set."""
+        assert not csv_data_set.exists()
+        csv_data_set.save(dummy_dataframe)
+        assert csv_data_set.exists()
+
+    @pytest.mark.parametrize(
+        "load_args", [{"k1": "v1", "index": "value"}], indirect=True
+    )
+    def test_load_extra_params(self, csv_data_set, load_args):
+        """Test overriding the default load arguments."""
+        for key, value in load_args.items():
+            assert csv_data_set._load_args[key] == value
+
+    @pytest.mark.parametrize(
+        "save_args", [{"k1": "v1", "index": "value"}], indirect=True
+    )
+    def test_save_extra_params(self, csv_data_set, save_args):
+        """Test overriding the default save arguments."""
+        for key, value in save_args.items():
+            assert csv_data_set._save_args[key] == value
+
+    @pytest.mark.parametrize(
+        "load_args,save_args",
+        [
+            ({"storage_options": {"a": "b"}}, {}),
+            ({}, {"storage_options": {"a": "b"}}),
+            ({"storage_options": {"a": "b"}}, {"storage_options": {"x": "y"}}),
+        ],
+    )
+    def test_storage_options_dropped(self, load_args, save_args, caplog, tmp_path):
+        filepath = str(tmp_path / "test.csv")
+
+        ds = CSVDataSet(filepath=filepath, load_args=load_args, save_args=save_args)
+
+        records = [r for r in caplog.records if r.levelname == "WARNING"]
+        expected_log_message = (
+            f"Dropping 'storage_options' for {filepath}, "
+            f"please specify them under 'fs_args' or 'credentials'."
+        )
+        assert records[0].getMessage() == expected_log_message
+        assert "storage_options" not in ds._save_args
+        assert "storage_options" not in ds._load_args
+
+    def test_load_missing_file(self, csv_data_set):
+        """Check the error when trying to load missing file."""
+        pattern = r"Failed while loading data from data set CSVDataSet\(.*\)"
+        with pytest.raises(DataSetError, match=pattern):
+            csv_data_set.load()
+
+    @pytest.mark.parametrize(
+        "filepath,instance_type,credentials",
+        [
+            ("s3://bucket/file.csv", S3FileSystem, {}),
+            ("file:///tmp/test.csv", LocalFileSystem, {}),
+            ("/tmp/test.csv", LocalFileSystem, {}),
+            ("gcs://bucket/file.csv", GCSFileSystem, {}),
+            ("https://example.com/file.csv", HTTPFileSystem, {}),
+            (
+                "abfs://bucket/file.csv",
+                AzureBlobFileSystem,
+                {"account_name": "test", "account_key": "test"},
+            ),
+        ],
+    )
+    def test_protocol_usage(self, filepath, instance_type, credentials):
+        data_set = CSVDataSet(filepath=filepath, credentials=credentials)
+        assert isinstance(data_set._fs, instance_type)
+
+        path = filepath.split(PROTOCOL_DELIMITER, 1)[-1]
+
+        assert str(data_set._filepath) == path
+        assert isinstance(data_set._filepath, PurePosixPath)
+
+    def test_catalog_release(self, mocker):
+        fs_mock = mocker.patch("fsspec.filesystem").return_value
+        filepath = "test.csv"
+        data_set = CSVDataSet(filepath=filepath)
+        assert data_set._version_cache.currsize == 0  # no cache if unversioned
+        data_set.release()
+        fs_mock.invalidate_cache.assert_called_once_with(filepath)
+        assert data_set._version_cache.currsize == 0
+
+
+class TestCSVDataSetVersioned:
+    def test_version_str_repr(self, load_version, save_version):
+        """Test that version is in string representation of the class instance
+        when applicable."""
+        filepath = "test.csv"
+        ds = CSVDataSet(filepath=filepath)
+        ds_versioned = CSVDataSet(
+            filepath=filepath, version=Version(load_version, save_version)
+        )
+        assert filepath in str(ds)
+        assert "version" not in str(ds)
+
+        assert filepath in str(ds_versioned)
+        ver_str = f"version=Version(load={load_version}, save='{save_version}')"
+        assert ver_str in str(ds_versioned)
+        assert "CSVDataSet" in str(ds_versioned)
+        assert "CSVDataSet" in str(ds)
+        assert "protocol" in str(ds_versioned)
+        assert "protocol" in str(ds)
+        # Default load_args
+        assert "load_args={'rechunk': True}" in str(ds)
+        assert "load_args={'rechunk': True}" in str(ds_versioned)
+
+    def test_save_and_load(self, versioned_csv_data_set, dummy_dataframe):
+        """Test that saved and reloaded data matches the original one for
+        the versioned data set."""
+        versioned_csv_data_set.save(dummy_dataframe)
+        reloaded_df = versioned_csv_data_set.load()
+        assert_frame_equal(dummy_dataframe, reloaded_df)
+
+    def test_multiple_loads(
+        self, versioned_csv_data_set, dummy_dataframe, filepath_csv
+    ):
+        """Test that if a new version is created mid-run, by an
+        external system, it won't be loaded in the current run."""
+        versioned_csv_data_set.save(dummy_dataframe)
+        versioned_csv_data_set.load()
+        v1 = versioned_csv_data_set.resolve_load_version()
+
+        sleep(0.5)
+        # force-drop a newer version into the same location
+        v_new = generate_timestamp()
+        CSVDataSet(filepath=filepath_csv, version=Version(v_new, v_new)).save(
+            dummy_dataframe
+        )
+
+        versioned_csv_data_set.load()
+        v2 = versioned_csv_data_set.resolve_load_version()
+
+        assert v2 == v1  # v2 should not be v_new!
+        ds_new = CSVDataSet(filepath=filepath_csv, version=Version(None, None))
+        assert (
+            ds_new.resolve_load_version() == v_new
+        )  # new version is discoverable by a new instance
+
+    def test_multiple_saves(self, dummy_dataframe, filepath_csv):
+        """Test multiple cycles of save followed by load for the same dataset"""
+        ds_versioned = CSVDataSet(filepath=filepath_csv, version=Version(None, None))
+
+        # first save
+        ds_versioned.save(dummy_dataframe)
+        first_save_version = ds_versioned.resolve_save_version()
+        first_load_version = ds_versioned.resolve_load_version()
+        assert first_load_version == first_save_version
+
+        # second save
+        sleep(0.5)
+        ds_versioned.save(dummy_dataframe)
+        second_save_version = ds_versioned.resolve_save_version()
+        second_load_version = ds_versioned.resolve_load_version()
+        assert second_load_version == second_save_version
+        assert second_load_version > first_load_version
+
+        # another dataset
+        ds_new = CSVDataSet(filepath=filepath_csv, version=Version(None, None))
+        assert ds_new.resolve_load_version() == second_load_version
+
+    def test_release_instance_cache(self, dummy_dataframe, filepath_csv):
+        """Test that cache invalidation does not affect other instances"""
+        ds_a = CSVDataSet(filepath=filepath_csv, version=Version(None, None))
+        assert ds_a._version_cache.currsize == 0
+        ds_a.save(dummy_dataframe)  # create a version
+        assert ds_a._version_cache.currsize == 2
+
+        ds_b = CSVDataSet(filepath=filepath_csv, version=Version(None, None))
+        assert ds_b._version_cache.currsize == 0
+        ds_b.resolve_save_version()
+        assert ds_b._version_cache.currsize == 1
+        ds_b.resolve_load_version()
+        assert ds_b._version_cache.currsize == 2
+
+        ds_a.release()
+
+        # dataset A cache is cleared
+        assert ds_a._version_cache.currsize == 0
+
+        # dataset B cache is unaffected
+        assert ds_b._version_cache.currsize == 2
+
+    def test_no_versions(self, versioned_csv_data_set):
+        """Check the error if no versions are available for load."""
+        pattern = r"Did not find any versions for CSVDataSet\(.+\)"
+        with pytest.raises(DataSetError, match=pattern):
+            versioned_csv_data_set.load()
+
+    def test_exists(self, versioned_csv_data_set, dummy_dataframe):
+        """Test `exists` method invocation for versioned data set."""
+        assert not versioned_csv_data_set.exists()
+        versioned_csv_data_set.save(dummy_dataframe)
+        assert versioned_csv_data_set.exists()
+
+    def test_prevent_overwrite(self, versioned_csv_data_set, dummy_dataframe):
+        """Check the error when attempting to override the data set if the
+        corresponding CSV file for a given save version already exists."""
+        versioned_csv_data_set.save(dummy_dataframe)
+        pattern = (
+            r"Save path \'.+\' for CSVDataSet\(.+\) must "
+            r"not exist if versioning is enabled\."
+        )
+        with pytest.raises(DataSetError, match=pattern):
+            versioned_csv_data_set.save(dummy_dataframe)
+
+    @pytest.mark.parametrize(
+        "load_version", ["2019-01-01T23.59.59.999Z"], indirect=True
+    )
+    @pytest.mark.parametrize(
+        "save_version", ["2019-01-02T00.00.00.000Z"], indirect=True
+    )
+    def test_save_version_warning(
+        self, versioned_csv_data_set, load_version, save_version, dummy_dataframe
+    ):
+        """Check the warning when saving to the path that differs from
+        the subsequent load path."""
+        pattern = (
+            rf"Save version '{save_version}' did not match load version "
+            rf"'{load_version}' for CSVDataSet\(.+\)"
+        )
+        with pytest.warns(UserWarning, match=pattern):
+            versioned_csv_data_set.save(dummy_dataframe)
+
+    def test_http_filesystem_no_versioning(self):
+        pattern = r"HTTP\(s\) DataSet doesn't support versioning\."
+
+        with pytest.raises(DataSetError, match=pattern):
+            CSVDataSet(
+                filepath="https://example.com/file.csv", version=Version(None, None)
+            )
+
+    def test_versioning_existing_dataset(
+        self, csv_data_set, versioned_csv_data_set, dummy_dataframe
+    ):
+        """Check the error when attempting to save a versioned dataset on top of an
+        already existing (non-versioned) dataset."""
+        csv_data_set.save(dummy_dataframe)
+        assert csv_data_set.exists()
+        assert csv_data_set._filepath == versioned_csv_data_set._filepath
+        pattern = (
+            f"(?=.*file with the same name already exists in the directory)"
+            f"(?=.*{versioned_csv_data_set._filepath.parent.as_posix()})"
+        )
+        with pytest.raises(DataSetError, match=pattern):
+            versioned_csv_data_set.save(dummy_dataframe)
+
+        # Remove non-versioned dataset and try again
+        Path(csv_data_set._filepath.as_posix()).unlink()
+        versioned_csv_data_set.save(dummy_dataframe)
+        assert versioned_csv_data_set.exists()
+
+
+class TestCSVDataSetS3:
+    os.environ["AWS_ACCESS_KEY_ID"] = "FAKE_ACCESS_KEY"
+    os.environ["AWS_SECRET_ACCESS_KEY"] = "FAKE_SECRET_KEY"
+
+    def test_load_and_confirm(self, mocker, mocked_csv_in_s3, mocked_dataframe):
+        """Test the standard flow for loading, confirming and reloading a
+        CSVDataSet in S3.
+
+        Unmodified test fails in Python >= 3.10 if executed after test_protocol_usage
+        (any implementation using S3FileSystem). Likely to be a bug with moto (tested
+        with moto==4.0.8, moto==3.0.4) -- see #67
+        """
+        df = CSVDataSet(mocked_csv_in_s3)
+        assert df._protocol == "s3"
+        # if Python >= 3.10, modify test procedure (see #67)
+        if sys.version_info[1] >= 10:
+            read_patch = mocker.patch("polars.read_csv", return_value=mocked_dataframe)
+            df.load()
+            read_patch.assert_called_once_with(
+                mocked_csv_in_s3, storage_options={}, rechunk=True
+            )
+        else:
+            loaded = df.load()
+            assert_frame_equal(loaded, mocked_dataframe)
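
For reviewers who want to try the new dataset end to end, the snippet below is a minimal usage sketch (illustrative only, not part of the patch). The file paths are made up, and the arguments shown are the ones used in the docstring example and tests above.

```python
import polars as pl
from kedro.io.core import Version

from kedro_datasets.polars import CSVDataSet

data = pl.DataFrame({"col1": [1, 2], "col2": [4, 5], "col3": [5, 6]})

# Unversioned round trip: save_args are forwarded to polars.DataFrame.write_csv,
# load_args to polars.read_csv (rechunk=True is applied by default).
data_set = CSVDataSet(filepath="data/01_raw/cars.csv", save_args={"has_header": True})
data_set.save(data)
assert data.frame_equal(data_set.load())

# Versioned round trip: with Version(None, None) the save version is
# autogenerated (a timestamp) and load resolves to the latest version.
versioned = CSVDataSet(
    filepath="data/01_raw/cars_versioned.csv", version=Version(None, None)
)
versioned.save(data)
assert data.frame_equal(versioned.load())
```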