[KED-862] Add versioning to Azure's CSVBlobDataSet (kedro-org#106)
mzjp2 authored and 921kiyo committed Oct 7, 2019
1 parent d404fe3 commit 0c5d98e
Showing 4 changed files with 262 additions and 18 deletions.
1 change: 1 addition & 0 deletions RELEASE.md
@@ -19,6 +19,7 @@
* Added a `--open` flag to `kedro build-docs` that opens the documentation on build.
* Updated ``Pipeline`` representation to include name and number of nodes, also making it readable as a context property.
* `kedro.contrib.io.pyspark.SparkDataSet` now supports versioning.
* `kedro.contrib.io.azure.CSVBlobDataSet` now supports versioning.

## Breaking changes to the API
* `KedroContext.run()` no longer accepts `catalog` and `pipeline` arguments.
1 change: 1 addition & 0 deletions docs/source/04_user_guide/08_advanced_io.md
@@ -218,4 +218,5 @@ Currently the following datasets support versioning:
- `ExcelLocalDataSet`
- `kedro.contrib.io.feather.FeatherLocalDataSet`
- `kedro.contrib.io.parquet.ParquetS3DataSet`
- `kedro.contrib.io.azure.CSVBlobDataSet`
- `kedro.contrib.io.pyspark.SparkDataSet`
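As context for the new entry above, a minimal sketch of constructing the versioned dataset (the filepath, container name, and credentials below are hypothetical placeholders):

```python
from kedro.contrib.io.azure import CSVBlobDataSet
from kedro.io import Version

# Hypothetical container and credentials, for illustration only.
data_set = CSVBlobDataSet(
    filepath="weather.csv",
    container_name="my-container",
    credentials={"account_name": "ACCOUNT", "account_key": "KEY"},
    # load=None loads the latest version; save=None autogenerates one.
    version=Version(load=None, save=None),
)
```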
86 changes: 70 additions & 16 deletions kedro/contrib/io/azure/csv_blob.py
@@ -29,17 +29,20 @@
""" ``AbstractDataSet`` implementation to access CSV files directly from
Microsoft's Azure blob storage.
"""
import copy
import io
from typing import Any, Dict, Optional
from functools import partial
from pathlib import PurePosixPath
from typing import Any, Dict, List, Optional

import pandas as pd
from azure.storage.blob import BlockBlobService

from kedro.contrib.io import DefaultArgumentsMixIn
from kedro.io import AbstractDataSet
from kedro.io import AbstractVersionedDataSet, DataSetError, Version


class CSVBlobDataSet(DefaultArgumentsMixIn, AbstractDataSet):
class CSVBlobDataSet(DefaultArgumentsMixIn, AbstractVersionedDataSet):
"""``CSVBlobDataSet`` loads and saves csv files in Microsoft's Azure
blob storage. It uses azure storage SDK to read and write in azure and
pandas to handle the csv file locally.
@@ -72,6 +75,7 @@ def _describe(self) -> Dict[str, Any]:
blob_from_text_args=self._blob_from_text_args,
load_args=self._load_args,
save_args=self._save_args,
version=self._version,
)

# pylint: disable=too-many-arguments
@@ -84,6 +88,7 @@ def __init__(
blob_from_text_args: Optional[Dict[str, Any]] = None,
load_args: Optional[Dict[str, Any]] = None,
save_args: Optional[Dict[str, Any]] = None,
version: Version = None,
) -> None:
"""Creates a new instance of ``CSVBlobDataSet`` pointing to a
concrete csv file on Azure blob storage.
@@ -95,10 +100,10 @@ def __init__(
``account_key`` or ``sas_token``) to access the Azure blob storage
blob_to_text_args: Any additional arguments to pass to azure's
``get_blob_to_text`` method:
https://docs.microsoft.com/en-us/python/api/azure.storage.blob.baseblobservice.baseblobservice?view=azure-python#get-blob-to-text
https://docs.microsoft.com/en-us/python/api/azure-storage-blob/azure.storage.blob.baseblobservice.baseblobservice?view=azure-python#get-blob-to-text
blob_from_text_args: Any additional arguments to pass to azure's
``create_blob_from_text`` method:
https://docs.microsoft.com/en-us/python/api/azure.storage.blob.blockblobservice.blockblobservice?view=azure-python#create-blob-from-text
https://docs.microsoft.com/en-us/python/api/azure-storage-blob/azure.storage.blob.blockblobservice.blockblobservice?view=azure-python#create-blob-from-text
load_args: Pandas options for loading csv files.
Here you can find all available arguments:
https://pandas.pydata.org/pandas-docs/stable/generated/pandas.read_csv.html
@@ -107,30 +112,79 @@ def __init__(
Here you can find all available arguments:
https://pandas.pydata.org/pandas-docs/stable/generated/pandas.DataFrame.to_csv.html
All defaults are preserved, except "index", which is set to False.
version: If specified, should be an instance of
``kedro.io.core.Version``. If its ``load`` attribute is
None, the latest version will be loaded. If its ``save``
attribute is None, the save version will be autogenerated.
"""
self._filepath = filepath
_credentials = copy.deepcopy(credentials)
_blob_service = BlockBlobService(**_credentials)
glob_function = partial(
_glob,
blob_service=_blob_service,
filepath=filepath,
container_name=container_name,
)
exists_function = partial(
_exists_blob, blob_service=_blob_service, container_name=container_name
)
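# AbstractVersionedDataSet resolves versioned paths lazily, so it receives
# callables that can list (glob) and probe (exists) blobs in this container.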

super().__init__(
load_args=load_args,
save_args=save_args,
filepath=PurePosixPath(filepath),
version=version,
exists_function=exists_function,
glob_function=glob_function,
)

self._blob_to_text_args = copy.deepcopy(blob_to_text_args) or {}
self._blob_from_text_args = copy.deepcopy(blob_from_text_args) or {}

self._container_name = container_name
self._credentials = credentials if credentials else {}
self._blob_to_text_args = blob_to_text_args if blob_to_text_args else {}
self._blob_from_text_args = blob_from_text_args if blob_from_text_args else {}
super().__init__(load_args, save_args)
self._credentials = _credentials
self._blob_service = _blob_service

def _load(self) -> pd.DataFrame:
blob_service = BlockBlobService(**self._credentials)
blob = blob_service.get_blob_to_text(
load_path = str(self._get_load_path())
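# load_path now points at the requested version, or the latest if none was specified.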
blob = self._blob_service.get_blob_to_text(
container_name=self._container_name,
blob_name=self._filepath,
blob_name=load_path,
**self._blob_to_text_args
)
csv_content = io.StringIO(blob.content)
return pd.read_csv(csv_content, **self._load_args)

def _save(self, data: pd.DataFrame) -> None:
blob_service = BlockBlobService(**self._credentials)
blob_service.create_blob_from_text(
save_path = self._get_save_path()

self._blob_service.create_blob_from_text(
container_name=self._container_name,
blob_name=self._filepath,
blob_name=str(save_path),
text=data.to_csv(**self._save_args),
**self._blob_from_text_args
)

load_path = self._get_load_path()
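# The base class warns if the path just saved to differs from the path
# a subsequent load would now resolve to.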
self._check_paths_consistency(load_path, save_path)

def _exists(self) -> bool:
try:
load_path = str(self._get_load_path())
except DataSetError:
return False
return _exists_blob(load_path, self._blob_service, self._container_name)


def _exists_blob(
filepath: str, blob_service: BlockBlobService, container_name: str
) -> bool:
return blob_service.exists(container_name, blob_name=filepath)


def _glob(
pattern: str, blob_service: BlockBlobService, container_name: str, filepath: str
) -> List[str]:
blob_paths = blob_service.list_blob_names(container_name, prefix=filepath)
return [path for path in blob_paths if PurePosixPath(path).match(pattern)]
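The `filepath/version/filepath` layout that `_glob` filters for can be illustrated with a small standalone sketch (the blob names below are hypothetical, and the pattern shape is assumed from kedro's versioned path convention):

```python
from pathlib import PurePosixPath

# Hypothetical blob names under the "test.csv" prefix.
blob_names = [
    "test.csv/2019-10-07T12.00.00.000Z/test.csv",
    "test.csv/2019-10-08T09.30.00.000Z/test.csv",
    "test.csv.backup",  # shares the prefix but is not a versioned entry
]

# The filtering step used in _glob: keep only paths matching the pattern.
pattern = "test.csv/*/test.csv"
versions = [p for p in blob_names if PurePosixPath(p).match(pattern)]
print(versions)  # both timestamped paths; the timestamps sort lexicographically
```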
192 changes: 190 additions & 2 deletions tests/contrib/io/azure/test_csv_blob.py
@@ -28,19 +28,33 @@

# pylint: disable=unused-argument

import io
from pathlib import PurePosixPath
from unittest.mock import patch

import pandas as pd
import pytest
from pandas.util.testing import assert_frame_equal

from kedro.contrib.io.azure import CSVBlobDataSet
from kedro.io import DataSetError
from kedro.io import DataSetError, Version
from kedro.io.core import generate_timestamp

TEST_FILE_NAME = "test.csv"
TEST_CONTAINER_NAME = "test_bucket"
TEST_CREDENTIALS = {"account_name": "ACCOUNT_NAME", "account_key": "ACCOUNT_KEY"}


@pytest.fixture(params=[None])
def load_version(request):
return request.param


@pytest.fixture(params=[None])
def save_version(request):
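# Defaults to a freshly generated timestamp when not parametrized.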
return request.param or generate_timestamp()


@pytest.fixture()
def dummy_dataframe():
return pd.DataFrame({"col1": [1, 2], "col2": [4, 5], "col3": [5, 6]})
@@ -62,6 +76,180 @@ def make_data_set(load_args=None, save_args=None):
return make_data_set


@pytest.fixture
def versioned_blob_csv_data_set(load_version, save_version):
return CSVBlobDataSet(
filepath=TEST_FILE_NAME,
container_name=TEST_CONTAINER_NAME,
credentials=TEST_CREDENTIALS,
blob_to_text_args={"to_extra": 41},
blob_from_text_args={"from_extra": 42},
version=Version(load_version, save_version),
)


@pytest.fixture
def save_path(save_version):
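# Versioned blobs are stored under "<filepath>/<version>/<filepath>".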
return "{0}/{1}/{0}".format(TEST_FILE_NAME, save_version)


class TestCSVBlobDataSetVersioned:
# pylint: disable=too-many-arguments
@patch("kedro.contrib.io.azure.csv_blob.BlockBlobService.create_blob_from_text")
@patch(
"kedro.contrib.io.azure.csv_blob.BlockBlobService.exists", return_value=False
)
@patch("kedro.contrib.io.azure.csv_blob.CSVBlobDataSet._get_load_path")
def test_save(
self,
load_mock,
exists_mock,
save_mock,
versioned_blob_csv_data_set,
dummy_dataframe,
save_path,
):
"""Test that saving saves with a correct version"""
versioned_blob_csv_data_set.save(dummy_dataframe)
save_mock.assert_called_with(
container_name=TEST_CONTAINER_NAME,
blob_name=save_path,
text=dummy_dataframe.to_csv(index=False),
from_extra=42,
)

@patch("kedro.contrib.io.azure.csv_blob.CSVBlobDataSet._get_load_path")
@patch("kedro.contrib.io.azure.csv_blob.BlockBlobService.get_blob_to_text")
def test_load(self, get_blob_mock, load_mock, versioned_blob_csv_data_set):
load_mock.return_value = TEST_FILE_NAME
get_blob_mock.return_value = BlobMock()
result = versioned_blob_csv_data_set.load()
get_blob_mock.assert_called_once_with(
container_name=TEST_CONTAINER_NAME, blob_name=TEST_FILE_NAME, to_extra=41
)
expected = pd.read_csv(io.StringIO(BlobMock().content))
assert_frame_equal(result, expected)

@patch(
"kedro.contrib.io.azure.csv_blob.BlockBlobService.list_blob_names",
return_value=[],
)
@patch(
"kedro.contrib.io.azure.csv_blob.BlockBlobService.exists", return_value=False
)
def test_no_versions(self, exists_mock, list_mock, versioned_blob_csv_data_set):
"""Check the error if no versions are available for load."""
pattern = r"Did not find any versions for CSVBlobDataSet\(.+\)"
with pytest.raises(DataSetError, match=pattern):
versioned_blob_csv_data_set.load()

# pylint: disable=too-many-arguments
@patch("kedro.contrib.io.azure.csv_blob.BlockBlobService.create_blob_from_text")
@patch(
"kedro.contrib.io.azure.csv_blob.BlockBlobService.exists", return_value=False
)
@patch("kedro.contrib.io.azure.csv_blob.CSVBlobDataSet._get_load_path")
def test_exists(
self,
load_mock,
exists_mock,
save_mock,
versioned_blob_csv_data_set,
dummy_dataframe,
save_path,
):
versioned_blob_csv_data_set.save(dummy_dataframe)
load_mock.return_value = PurePosixPath(save_path)
versioned_blob_csv_data_set.exists()
exists_mock.assert_called_with(TEST_CONTAINER_NAME, blob_name=save_path)

@patch("kedro.contrib.io.azure.csv_blob.BlockBlobService.create_blob_from_text")
@patch(
"kedro.contrib.io.azure.csv_blob.BlockBlobService.exists", return_value=False
)
@patch("kedro.contrib.io.azure.csv_blob.CSVBlobDataSet._get_load_path")
def test_exists_dataset_error(
self,
load_mock,
exists_mock,
save_mock,
versioned_blob_csv_data_set,
dummy_dataframe,
save_path,
):
versioned_blob_csv_data_set.save(dummy_dataframe)
load_mock.side_effect = DataSetError
assert not versioned_blob_csv_data_set.exists()

@patch("kedro.contrib.io.azure.csv_blob.BlockBlobService.exists", return_value=True)
@patch("kedro.contrib.io.azure.csv_blob.CSVBlobDataSet._get_load_path")
def test_prevent_override(
self, load_mock, exists_mock, versioned_blob_csv_data_set, dummy_dataframe
):
"""Check the error when attempting to override the data set if the
corresponding csv file for a given save version already exists in blob storage.
"""
pattern = (
r"Save path \`.+\` for CSVBlobDataSet\(.+\) must not exist "
r"if versioning is enabled"
)
with pytest.raises(DataSetError, match=pattern):
versioned_blob_csv_data_set.save(dummy_dataframe)

@patch("kedro.contrib.io.azure.csv_blob.BlockBlobService.create_blob_from_text")
@patch("kedro.contrib.io.azure.csv_blob.CSVBlobDataSet._get_save_path")
@patch("kedro.contrib.io.azure.csv_blob.CSVBlobDataSet._get_load_path")
def test_save_version_warning(
self,
load_mock,
save_mock,
create_blob_mock,
versioned_blob_csv_data_set,
dummy_dataframe,
):
"""Check the warning when saving to the path that differs from
the subsequent load path."""
save_version = "2019-01-02T00.00.00.000Z"
load_version = "2019-01-01T23.59.59.999Z"
pattern = (
r"Save path `{f}/{sv}/{f}` did not match load path "
r"`{f}/{lv}/{f}` for CSVBlobDataSet\(.+\)".format(
f=TEST_FILE_NAME, sv=save_version, lv=load_version
)
)
load_mock.return_value = PurePosixPath(
"{0}/{1}/{0}".format(TEST_FILE_NAME, load_version)
)
save_mock.return_value = PurePosixPath(
"{0}/{1}/{0}".format(TEST_FILE_NAME, save_version)
)
with pytest.warns(UserWarning, match=pattern):
versioned_blob_csv_data_set.save(dummy_dataframe)

def test_version_str_repr(self, load_version, save_version):
"""Test that version is in string representation of the class instance
when applicable."""
ds = CSVBlobDataSet(
filepath=TEST_FILE_NAME,
container_name=TEST_CONTAINER_NAME,
credentials=TEST_CREDENTIALS,
)
ds_versioned = CSVBlobDataSet(
filepath=TEST_FILE_NAME,
container_name=TEST_CONTAINER_NAME,
credentials=TEST_CREDENTIALS,
version=Version(load_version, save_version),
)
assert TEST_FILE_NAME in str(ds)
assert "version" not in str(ds)

assert TEST_FILE_NAME in str(ds_versioned)
ver_str = "version=Version(load={}, save='{}')".format(
load_version, save_version
)
assert ver_str in str(ds_versioned)


@patch("kedro.contrib.io.azure.csv_blob.BlockBlobService")
def test_pass_credentials_load(blob_service, blob_csv_data_set):
try:
@@ -119,7 +307,7 @@ def test_load(get_blob_mock, blob_csv_data_set):
result = blob_csv_data_set().load()[["name", "age"]]
expected = pd.DataFrame({"name": ["tom", "bob"], "age": [3, 4]})
expected = expected[["name", "age"]]
assert result.equals(expected)
assert_frame_equal(result, expected)


@patch("kedro.contrib.io.azure.csv_blob.BlockBlobService.create_blob_from_text")