Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Implement plotly.HTMLDataset #788

Merged
merged 8 commits into from
Aug 22, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions kedro-datasets/RELEASE.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,18 @@
|-------------------------------------|-----------------------------------------------------------|-----------------------------------------|
| `pytorch.PyTorchDataset` | A dataset for securely saving and loading PyTorch models | `kedro_datasets_experimental.pytorch` |

* Added the following new core datasets:

| Type | Description | Location |
|----------------------|------------------------------------------------|-------------------------|
| `plotly.HTMLDataset` | A dataset for saving a `plotly` figure as HTML | `kedro_datasets.plotly` |

## Bug fixes and other changes
## Breaking Changes
## Community contributions
Many thanks to the following Kedroids for contributing PRs to this release:
* [Brandon Meek](https://github.com/bpmeek)
* [yury-fedotov](https://github.com/yury-fedotov)


# Release 4.1.0
Expand All @@ -23,6 +30,7 @@ Many thanks to the following Kedroids for contributing PRs to this release:
## Breaking Changes
## Community contributions


# Release 4.0.0
## Major features and improvements

Expand Down
1 change: 1 addition & 0 deletions kedro-datasets/docs/source/api/kedro_datasets.rst
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ kedro_datasets
partitions.PartitionedDataset
pickle.PickleDataset
pillow.ImageDataset
plotly.HTMLDataset
plotly.JSONDataset
plotly.PlotlyDataset
polars.CSVDataset
Expand Down
2 changes: 2 additions & 0 deletions kedro-datasets/kedro_datasets/plotly/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,12 @@
# https://github.com/pylint-dev/pylint/issues/4300#issuecomment-1043601901
JSONDataset: Any
PlotlyDataset: Any
HTMLDataset: Any

__getattr__, __dir__, __all__ = lazy.attach(
__name__,
submod_attrs={
"html_dataset": ["HTMLDataset"],
"json_dataset": ["JSONDataset"],
"plotly_dataset": ["PlotlyDataset"],
},
Expand Down
154 changes: 154 additions & 0 deletions kedro-datasets/kedro_datasets/plotly/html_dataset.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,154 @@
"""``HTMLDataset`` saves a plotly figure to an HTML file using an underlying
filesystem (e.g.: local, S3, GCS).
"""
from __future__ import annotations

from copy import deepcopy
from pathlib import PurePosixPath
from typing import Any, NoReturn, Union

import fsspec
from kedro.io.core import (
AbstractVersionedDataset,
DatasetError,
Version,
get_filepath_str,
get_protocol_and_path,
)
from plotly import graph_objects as go


class HTMLDataset(
    AbstractVersionedDataset[go.Figure, Union[go.Figure, go.FigureWidget]]
):
    """``HTMLDataset`` saves a plotly figure to an HTML file using an
    underlying filesystem (e.g.: local, S3, GCS).

    The dataset is write-only: calling ``load`` raises a ``DatasetError``.

    Example usage for the
    `YAML API <https://kedro.readthedocs.io/en/stable/data/\
    data_catalog_yaml_examples.html>`_:

    .. code-block:: yaml

        scatter_plot:
          type: plotly.HTMLDataset
          filepath: data/08_reporting/scatter_plot.html
          save_args:
            auto_open: False

    Example usage for the
    `Python API <https://kedro.readthedocs.io/en/stable/data/\
    advanced_data_catalog_usage.html>`_:

    .. code-block:: pycon

        >>> from kedro_datasets.plotly import HTMLDataset
        >>> import plotly.express as px
        >>>
        >>> fig = px.bar(x=["a", "b", "c"], y=[1, 3, 2])
        >>> dataset = HTMLDataset(filepath=tmp_path / "test.html")
        >>> dataset.save(fig)
    """

    # Keyword arguments merged under the user's ``save_args`` and forwarded to
    # ``Figure.write_html``; empty, so plotly's own defaults apply.
    DEFAULT_SAVE_ARGS: dict[str, Any] = {}
    # Keyword arguments merged under the user's ``open_args_save`` and
    # forwarded to the filesystem's ``open`` call when saving.
    DEFAULT_FS_ARGS: dict[str, Any] = {
        "open_args_save": {"mode": "w", "encoding": "utf-8"}
    }

    def __init__(  # noqa: PLR0913
        self,
        *,
        filepath: str,
        save_args: dict[str, Any] | None = None,
        version: Version | None = None,
        credentials: dict[str, Any] | None = None,
        fs_args: dict[str, Any] | None = None,
        metadata: dict[str, Any] | None = None,
    ) -> None:
        """Creates a new instance of ``HTMLDataset`` pointing to a concrete HTML file
        on a specific filesystem.

        Args:
            filepath: Filepath in POSIX format to an HTML file prefixed with a protocol like `s3://`.
                If prefix is not provided `file` protocol (local filesystem) will be used.
                The prefix should be any protocol supported by ``fsspec``.
                Note: `http(s)` doesn't support versioning.
            save_args: Plotly options for saving HTML files.
                Here you can find all available arguments:
                https://plotly.com/python-api-reference/generated/plotly.io.write_html.html#plotly.io.write_html
                All defaults are preserved.
            version: If specified, should be an instance of
                ``kedro.io.core.Version``. If its ``load`` attribute is
                None, the latest version will be loaded. If its ``save``
                attribute is None, save version will be autogenerated.
            credentials: Credentials required to get access to the underlying filesystem.
                E.g. for ``GCSFileSystem`` it should look like `{'token': None}`.
            fs_args: Extra arguments to pass into underlying filesystem class constructor
                (e.g. `{"project": "my-project"}` for ``GCSFileSystem``), as well as
                to pass to the filesystem's `open` method through the nested key
                `open_args_save`. A nested `open_args_load` key is accepted but
                ignored, because this dataset does not support loading.
                Here you can find all available arguments for `open`:
                https://filesystem-spec.readthedocs.io/en/latest/api.html#fsspec.spec.AbstractFileSystem.open
                All defaults are preserved, except `mode`, which is set to `w`, and
                `encoding`, which is set to `utf-8`, when saving.
            metadata: Any arbitrary metadata.
                This is ignored by Kedro, but may be consumed by users or external plugins.
        """
        # Work on copies so the caller's dictionaries are never mutated.
        _fs_args = deepcopy(fs_args) or {}
        _fs_open_args_save = _fs_args.pop("open_args_save", {})
        # Loading is unsupported, so ``open_args_load`` has no consumer. It is
        # still removed here so the nested key is not forwarded to the
        # filesystem constructor as if it were a storage option.
        _fs_args.pop("open_args_load", None)
        _credentials = deepcopy(credentials) or {}

        protocol, path = get_protocol_and_path(filepath, version)
        if protocol == "file":
            # Let the local filesystem create missing parent directories.
            _fs_args.setdefault("auto_mkdir", True)

        self._protocol = protocol
        self._fs = fsspec.filesystem(self._protocol, **_credentials, **_fs_args)

        self.metadata = metadata

        super().__init__(
            filepath=PurePosixPath(path),
            version=version,
            exists_function=self._fs.exists,
            glob_function=self._fs.glob,
        )

        # Handle default save and fs arguments; user-supplied values win.
        self._save_args = {**self.DEFAULT_SAVE_ARGS, **(save_args or {})}
        self._fs_open_args_save = {
            **self.DEFAULT_FS_ARGS.get("open_args_save", {}),
            **(_fs_open_args_save or {}),
        }

    def _describe(self) -> dict[str, Any]:
        """Return the filepath, protocol, save arguments and version of the dataset."""
        return {
            "filepath": self._filepath,
            "protocol": self._protocol,
            "save_args": self._save_args,
            "version": self._version,
        }

    def _load(self) -> NoReturn:
        """Always raise, because this dataset cannot be loaded.

        Raises:
            DatasetError: On every call.
        """
        raise DatasetError(f"Loading not supported for '{self.__class__.__name__}'")

    def _save(self, data: go.Figure) -> None:
        """Write ``data`` as HTML to the resolved save path.

        Args:
            data: The plotly figure to write; its ``write_html`` method is
                called with the opened file and the configured save arguments.
        """
        save_path = get_filepath_str(self._get_save_path(), self._protocol)

        with self._fs.open(save_path, **self._fs_open_args_save) as fs_file:
            data.write_html(fs_file, **self._save_args)

        self._invalidate_cache()

    def _exists(self) -> bool:
        """Return whether the resolved load path exists on the filesystem."""
        load_path = get_filepath_str(self._get_load_path(), self._protocol)

        return self._fs.exists(load_path)

    def _release(self) -> None:
        """Run the parent release logic, then invalidate the filesystem cache."""
        super()._release()
        self._invalidate_cache()

    def _invalidate_cache(self) -> None:
        """Invalidate the filesystem cache entry for this dataset's filepath."""
        filepath = get_filepath_str(self._filepath, self._protocol)
        self._fs.invalidate_cache(filepath)
3 changes: 2 additions & 1 deletion kedro-datasets/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -120,9 +120,10 @@ pickle = ["kedro-datasets[pickle-pickledataset]"]
pillow-imagedataset = ["Pillow~=9.0"]
pillow = ["kedro-datasets[pillow-imagedataset]"]

plotly-htmldataset = ["kedro-datasets[plotly-base]"]
plotly-jsondataset = ["kedro-datasets[plotly-base]"]
plotly-plotlydataset = ["kedro-datasets[pandas-base,plotly-base]"]
plotly = ["kedro-datasets[plotly-jsondataset,plotly-plotlydataset]"]
plotly = ["kedro-datasets[plotly-htmldataset,plotly-jsondataset,plotly-plotlydataset]"]

polars-csvdataset = ["kedro-datasets[polars-base]"]
polars-eagerpolarsdataset = ["kedro-datasets[polars-base]", "pyarrow>=4.0", "xlsx2csv>=0.8.0", "deltalake >= 0.6.2"]
Expand Down
88 changes: 88 additions & 0 deletions kedro-datasets/tests/plotly/test_html_dataset.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
from pathlib import PurePosixPath

import plotly.express as px
import pytest
from adlfs import AzureBlobFileSystem
from fsspec.implementations.http import HTTPFileSystem
from fsspec.implementations.local import LocalFileSystem
from gcsfs import GCSFileSystem
from kedro.io.core import PROTOCOL_DELIMITER, DatasetError
from s3fs.core import S3FileSystem

from kedro_datasets.plotly import HTMLDataset


@pytest.fixture
def filepath_html(tmp_path):
    """Return the POSIX-style path of ``test.html`` under ``tmp_path``."""
    html_file = tmp_path / "test.html"
    return html_file.as_posix()


@pytest.fixture
def html_dataset(filepath_html, save_args, fs_args):
    """Build an ``HTMLDataset`` from the path, save-args and fs-args fixtures.

    ``save_args`` and ``fs_args`` are not defined in this module; presumably
    they come from a shared conftest -- verify there.
    """
    dataset_kwargs = {
        "filepath": filepath_html,
        "save_args": save_args,
        "fs_args": fs_args,
    }
    return HTMLDataset(**dataset_kwargs)


@pytest.fixture
def dummy_plot():
    """Return a three-point scatter figure titled ``Test``."""
    xs, ys = [1, 2, 3], [1, 3, 2]
    return px.scatter(x=xs, y=ys, title="Test")


class TestHTMLDataset:
    """Tests for ``HTMLDataset``: saving, existence, load rejection and setup."""

    def test_save(self, html_dataset, dummy_plot):
        """Saving a figure succeeds and the default open arguments are in place."""
        html_dataset.save(dummy_plot)
        expected_open_args = {"mode": "w", "encoding": "utf-8"}
        assert html_dataset._fs_open_args_save == expected_open_args

    def test_exists(self, html_dataset, dummy_plot):
        """``exists`` is false before the first save and true afterwards."""
        assert not html_dataset.exists()
        html_dataset.save(dummy_plot)
        assert html_dataset.exists()

    def test_load_is_impossible(self, html_dataset):
        """Calling ``load`` raises a ``DatasetError``."""
        with pytest.raises(DatasetError, match="Loading not supported"):
            html_dataset.load()

    @pytest.mark.parametrize("save_args", [{"auto_play": False}])
    def test_save_extra_params(self, html_dataset, save_args):
        """Every user-provided save argument is stored on the dataset."""
        stored = {key: html_dataset._save_args[key] for key in save_args}
        assert stored == save_args

    @pytest.mark.parametrize(
        "filepath,instance_type,credentials",
        [
            ("s3://bucket/file.html", S3FileSystem, {}),
            ("file:///tmp/test.html", LocalFileSystem, {}),
            ("/tmp/test.html", LocalFileSystem, {}),
            ("gcs://bucket/file.html", GCSFileSystem, {}),
            ("https://example.com/file.html", HTTPFileSystem, {}),
            (
                "abfs://bucket/file.csv",
                AzureBlobFileSystem,
                {"account_name": "test", "account_key": "test"},
            ),
        ],
    )
    def test_protocol_usage(self, filepath, instance_type, credentials):
        """The protocol selects the filesystem class and is stripped from the path."""
        dataset = HTMLDataset(filepath=filepath, credentials=credentials)
        assert isinstance(dataset._fs, instance_type)

        # Everything after the first protocol delimiter, or the whole string
        # when no delimiter is present.
        *_, expected_path = filepath.split(PROTOCOL_DELIMITER, 1)

        assert str(dataset._filepath) == expected_path
        assert isinstance(dataset._filepath, PurePosixPath)

    def test_catalog_release(self, mocker):
        """Releasing the dataset invalidates the filesystem cache exactly once."""
        filepath = "test.html"
        patched_factory = mocker.patch("fsspec.filesystem")
        fs_mock = patched_factory.return_value
        dataset = HTMLDataset(filepath=filepath)
        dataset.release()
        fs_mock.invalidate_cache.assert_called_once_with(filepath)