diff --git a/.circleci/continue_config.yml b/.circleci/continue_config.yml index c26fb2fbc0..5446be9a41 100644 --- a/.circleci/continue_config.yml +++ b/.circleci/continue_config.yml @@ -327,10 +327,10 @@ jobs: - win_setup: python_version: <> - when: - # Save Python package cache only for Python 3.7. The conda environment itself + # Save Python package cache only for Python 3.8. The conda environment itself # is specific to a Python version and is cached separately for each. condition: - equal: ["3.7", <>] + equal: [<>] steps: - save_cache: name: Save Python package cache @@ -458,7 +458,7 @@ jobs: publish_kedro: executor: name: docker - python_version: "3.7" + python_version: "3.8" steps: - setup - add_ssh_keys @@ -528,7 +528,7 @@ workflows: - lint: matrix: parameters: - python_version: ["3.7", "3.8", "3.9", "3.10"] + python_version: ["3.8", "3.9", "3.10"] - all_circleci_checks_succeeded: requires: - lint @@ -544,31 +544,31 @@ workflows: - e2e_tests: matrix: parameters: - python_version: ["3.7", "3.8", "3.9", "3.10"] + python_version: ["3.8", "3.9", "3.10"] - win_e2e_tests: matrix: parameters: - python_version: ["3.7", "3.8", "3.9", "3.10"] + python_version: ["3.8", "3.9", "3.10"] - unit_tests: matrix: parameters: - python_version: ["3.7", "3.8", "3.9", "3.10"] + python_version: ["3.8", "3.9", "3.10"] - win_unit_tests: matrix: parameters: - python_version: ["3.7", "3.8", "3.9", "3.10"] + python_version: ["3.8", "3.9", "3.10"] - lint: matrix: parameters: - python_version: ["3.7", "3.8", "3.9", "3.10"] + python_version: ["3.8", "3.9", "3.10"] - pip_compile: matrix: parameters: - python_version: ["3.7", "3.8", "3.9", "3.10"] + python_version: ["3.8", "3.9", "3.10"] - win_pip_compile: matrix: parameters: - python_version: ["3.7", "3.8", "3.9", "3.10"] + python_version: ["3.8", "3.9", "3.10"] - all_circleci_checks_succeeded: requires: - e2e_tests @@ -620,14 +620,14 @@ workflows: - kedro-ecr-publish matrix: parameters: - python_version: ["3.7", "3.8", "3.9", "3.10"] + python_version: ["3.8", "3.9", "3.10"] filters: branches: only: main - build_kedro: matrix: parameters: - python_version: ["3.7", "3.8", "3.9", "3.10"] + python_version: ["3.8", "3.9", "3.10"] requires: - build_docker_image-<> filters: @@ -672,7 +672,7 @@ workflows: - build_kedro: matrix: parameters: - python_version: ["3.7", "3.8", "3.9", "3.10"] + python_version: ["3.8", "3.9", "3.10"] - publish_kedro: requires: - build_kedro diff --git a/.github/PULL_REQUEST_TEMPLATE.md b/.github/PULL_REQUEST_TEMPLATE.md index ab1ea67ca6..b4964ff636 100644 --- a/.github/PULL_REQUEST_TEMPLATE.md +++ b/.github/PULL_REQUEST_TEMPLATE.md @@ -1,6 +1,3 @@ -> **NOTE:** Kedro datasets are moving from `kedro.extras.datasets` to a separate `kedro-datasets` package in -> [`kedro-plugins` repository](https://github.com/kedro-org/kedro-plugins). Any changes to the dataset implementations -> should be done by opening a pull request in that repository. 
## Description diff --git a/.github/workflows/all-checks.yml b/.github/workflows/all-checks.yml index 5fe306d018..61e0cf81a6 100644 --- a/.github/workflows/all-checks.yml +++ b/.github/workflows/all-checks.yml @@ -22,7 +22,7 @@ jobs: strategy: matrix: os: [ ubuntu-latest, windows-latest ] - python-version: [ "3.7", "3.8", "3.9", "3.10", "3.11" ] + python-version: [ "3.8", "3.9", "3.10", "3.11" ] uses: ./.github/workflows/unit-tests.yml with: os: ${{ matrix.os }} @@ -42,7 +42,7 @@ jobs: strategy: matrix: os: [ ubuntu-latest, windows-latest ] - python-version: [ "3.7", "3.8", "3.9", "3.10", "3.11" ] + python-version: [ "3.8", "3.9", "3.10", "3.11" ] uses: ./.github/workflows/e2e-tests.yml with: os: ${{ matrix.os }} @@ -52,7 +52,7 @@ jobs: strategy: matrix: os: [ ubuntu-latest, windows-latest ] - python-version: [ "3.7", "3.8", "3.9", "3.10", "3.11" ] + python-version: [ "3.8", "3.9", "3.10", "3.11" ] uses: ./.github/workflows/pip-compile.yml with: os: ${{ matrix.os }} diff --git a/.github/workflows/docs-only-checks.yml b/.github/workflows/docs-only-checks.yml index d13a76c589..1af4aa53e8 100644 --- a/.github/workflows/docs-only-checks.yml +++ b/.github/workflows/docs-only-checks.yml @@ -21,7 +21,7 @@ jobs: strategy: matrix: os: [ ubuntu-latest ] - python-version: [ "3.7", "3.8", "3.9", "3.10", "3.11" ] + python-version: [ "3.8", "3.9", "3.10", "3.11" ] uses: ./.github/workflows/lint.yml with: os: ${{ matrix.os }} diff --git a/.github/workflows/unit-tests.yml b/.github/workflows/unit-tests.yml index 347940ea86..5d2c982b87 100644 --- a/.github/workflows/unit-tests.yml +++ b/.github/workflows/unit-tests.yml @@ -36,14 +36,7 @@ jobs: run: | make install-test-requirements make install-pre-commit - - name: Install pytables (only for windows) - if: inputs.os == 'windows-latest' - run: pip install tables - name: pip freeze run: pip freeze - name: Run unit tests - if: inputs.os == 'ubuntu-latest' run: make test - - name: Run unit tests without spark (Windows) - if: inputs.os == 'windows-latest' - run: make test-no-spark diff --git a/Makefile b/Makefile index e241e5ffa1..4d768cd0d4 100644 --- a/Makefile +++ b/Makefile @@ -1,8 +1,8 @@ install: - pip install . + pip install -e . clean: - rm -rf build dist docs/build kedro/html pip-wheel-metadata .mypy_cache .pytest_cache features/steps/test_plugin/test_plugin.egg-info kedro/datasets + rm -rf build dist docs/build kedro/html pip-wheel-metadata .mypy_cache .pytest_cache features/steps/test_plugin/test_plugin.egg-info find . -regex ".*/__pycache__" -exec rm -rf {} + find . -regex ".*\.egg-info" -exec rm -rf {} + pre-commit clean || true @@ -12,18 +12,6 @@ lint: test: pytest --numprocesses 4 --dist loadfile -test-no-spark: - pytest --no-cov --ignore tests/extras/datasets/spark --numprocesses 4 --dist loadfile - -test-sequential: - pytest tests --cov-config pyproject.toml - -test-no-spark-sequential: - pytest tests --no-cov --ignore tests/extras/datasets/spark - -test-no-datasets: - pytest --no-cov --ignore tests/extras/datasets/ --numprocesses 4 --dist loadfile - e2e-tests: behave @@ -33,8 +21,6 @@ pip-compile: secret-scan: trufflehog --max_depth 1 --exclude_paths trufflehog-ignore.txt . -SPHINXPROJ = Kedro - build-docs: pip install -e ".[docs]" ./docs/build-docs.sh "docs" diff --git a/RELEASE.md b/RELEASE.md index c79395c06c..091f6a0790 100644 --- a/RELEASE.md +++ b/RELEASE.md @@ -5,8 +5,29 @@ ## Bug fixes and other changes ## Breaking changes to the API +* Drop Python 3.7 support. + +### DataSets +* Remove `kedro.extras.datasets` and tests. 
+* Reduced constructor arguments for `APIDataSet` by replacing most arguments with a single constructor argument `load_args`. This makes it more consistent with other Kedro DataSets and the underlying `requests` API, and automatically enables the full configuration domain: stream, certificates, proxies, and more. + +### CLI +* Removed deprecated `kedro docs` command. + +### ConfigLoader +* `logging` is removed from `ConfigLoader` in favour of the environment variable `KEDRO_LOGGING_CONFIG`. + +### Other +* Removed deprecated `kedro.extras.ColorHandler`. +* The Kedro IPython extension is no longer available as `%load_ext kedro.extras.extensions.ipython`; use `%load_ext kedro.ipython` instead. +* Anonymous nodes are given default names of the form `([in1;in2;...]) -> [out1;out2;...]`, with the names of inputs and outputs separated by semicolons. ## Migration guide from Kedro 0.18.* to 0.19.* +### DataSets +* If you use `APIDataSet`, move all `requests` specific arguments (e.g. `params`, `headers`), except for `url` and `method`, to under `load_args`. +### Logging +`logging.yml` is now independent of Kedro's run environment and only used if `KEDRO_LOGGING_CONFIG` is set to point to it. + # Upcoming Release 0.18.14 ## Major features and improvements @@ -106,6 +127,10 @@ Thanks to [Laíza Milena Scheid Parizotto](https://github.com/laizaparizotto) an ## Documentation changes * Significant improvements to the documentation that covers working with Databricks and Kedro, including a new page for workspace-only development, and a guide to choosing the best workflow for your use case. * Updated documentation for deploying with Prefect for version 2.0. +* Added documentation for developing a Kedro project using a Databricks workspace. + +## Breaking changes to the API +* Logging is decoupled from `ConfigLoader`, use `KEDRO_LOGGING_CONFIG` to configure logging. ## Upcoming deprecations for Kedro 0.19.0 * Renamed dataset and error classes, in accordance with the [Kedro lexicon](https://github.com/kedro-org/kedro/wiki/Kedro-documentation-style-guide#kedro-lexicon). Dataset classes ending with "DataSet" and error classes starting with "DataSet" are deprecated and will be removed in 0.19.0. Note that all of the below classes are also importable from `kedro.io`; only the module where they are defined is listed as the location. @@ -190,6 +215,7 @@ Many thanks to the following Kedroids for contributing PRs to this release: * [MaximeSteinmetz](https://github.com/MaximeSteinmetz) + # Release 0.18.7 ## Major features and improvements @@ -202,6 +228,7 @@ Many thanks to the following Kedroids for contributing PRs to this release: * Added a guide and tooling for developing Kedro for Databricks. * Implemented missing dict-like interface for `_ProjectPipeline`. + # Release 0.18.6 ## Bug fixes and other changes @@ -216,6 +243,7 @@ A regression introduced in Kedro version `0.18.5` caused the `Kedro-Viz` console Thanks to Kedroids tomohiko kato, [tsanikgr](https://github.com/tsanikgr) and [maddataanalyst](https://github.com/maddataanalyst) for very detailed reports about the bug. + # Release 0.18.5 > This release introduced a bug that causes a failure in experiment tracking within the `Kedro-Viz` console. We recommend that you use Kedro version `0.18.6` in preference. @@ -390,6 +418,7 @@ We are grateful to the following for submitting PRs that contributed to this rel ## Bug fixes and other changes * Removed fatal error from being logged when a Kedro session is created in a directory without git. 
+* `KedroContext` is now an `attrs`'s frozen class and `config_loader` is available as public attribute. * Fixed `CONFIG_LOADER_CLASS` validation so that `TemplatedConfigLoader` can be specified in settings.py. Any `CONFIG_LOADER_CLASS` must be a subclass of `AbstractConfigLoader`. * Added runner name to the `run_params` dictionary used in pipeline hooks. * Updated [Databricks documentation](https://docs.kedro.org/en/0.18.1/deployment/databricks.html) to include how to get it working with IPython extension and Kedro-Viz. @@ -403,6 +432,10 @@ We are grateful to the following for submitting PRs that contributed to this rel ## Upcoming deprecations for Kedro 0.19.0 * `kedro docs` will be removed in 0.19.0. +## Upcoming deprecations for Kedro 0.19.0 +* `kedro docs` will be removed in 0.19.0. + + # Release 0.18.0 ## TL;DR ✨ diff --git a/docs/source/conf.py b/docs/source/conf.py index e80f9b2b29..77bad9c020 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -136,7 +136,6 @@ "CONF_SOURCE", "integer -- return number of occurrences of value", "integer -- return first index of value.", - "kedro.extras.datasets.pandas.json_dataset.JSONDataSet", "kedro_datasets.pandas.json_dataset.JSONDataSet", "pluggy._manager.PluginManager", "PluginManager", @@ -319,8 +318,6 @@ "kedro.pipeline", "kedro.runner", "kedro.config", - "kedro.extras.datasets", - "kedro.extras.logging", "kedro_datasets", ] diff --git a/docs/source/configuration/advanced_configuration.md b/docs/source/configuration/advanced_configuration.md index d4b7e12b6b..f536afe70d 100644 --- a/docs/source/configuration/advanced_configuration.md +++ b/docs/source/configuration/advanced_configuration.md @@ -148,7 +148,7 @@ CONFIG_LOADER_ARGS = { By changing this setting, the default behaviour for loading parameters will be replaced, while the other configuration patterns will remain in their default state. ### How to ensure non default configuration files get loaded -You can add configuration patterns to match files other than `parameters`, `credentials`, `logging`, and `catalog` by setting the `CONFIG_LOADER_ARGS` variable in [`src//settings.py`](../kedro_project_setup/settings.md). +You can add configuration patterns to match files other than `parameters`, `credentials`, and `catalog` by setting the `CONFIG_LOADER_ARGS` variable in [`src//settings.py`](../kedro_project_setup/settings.md). For example, if you want to load Spark configuration files you need to update `CONFIG_LOADER_ARGS` as follows: ```python @@ -160,7 +160,7 @@ CONFIG_LOADER_ARGS = { ``` ### How to bypass the configuration loading rules -You can bypass the configuration patterns and set configuration directly on the instance of a config loader class. You can bypass the default configuration (catalog, parameters, credentials, and logging) as well as additional configuration. +You can bypass the configuration patterns and set configuration directly on the instance of a config loader class. You can bypass the default configuration (catalog, parameters and credentials) as well as additional configuration. 
```{code-block} python :lineno-start: 10 diff --git a/docs/source/configuration/configuration_basics.md b/docs/source/configuration/configuration_basics.md index 9d53d49e2f..32a5e37946 100644 --- a/docs/source/configuration/configuration_basics.md +++ b/docs/source/configuration/configuration_basics.md @@ -1,6 +1,6 @@ # Configuration -This section contains detailed information about Kedro project configuration, which you can use to store settings for your project such as [parameters](./parameters.md), [credentials](./credentials.md), the [data catalog](../data/data_catalog.md), and [logging information](../logging/logging.md). +This section contains detailed information about Kedro project configuration, which you can use to store settings for your project such as [parameters](./parameters.md), [credentials](./credentials.md), the [data catalog](../data/data_catalog.md), and [logging information](../logging/index.md). Kedro makes use of a configuration loader to load any project configuration files, and the available configuration loader classes are: diff --git a/docs/source/deployment/aws_batch.md b/docs/source/deployment/aws_batch.md index 5856701608..5ff9ffa9bc 100644 --- a/docs/source/deployment/aws_batch.md +++ b/docs/source/deployment/aws_batch.md @@ -169,7 +169,7 @@ class AWSBatchRunner(ThreadRunner): return super()._get_required_workers_count(pipeline) - def _run( # pylint: disable=too-many-locals,useless-suppression + def _run( self, pipeline: Pipeline, catalog: DataCatalog, diff --git a/docs/source/deployment/databricks/databricks_ide_development_workflow.md b/docs/source/deployment/databricks/databricks_ide_development_workflow.md index 80bac20815..0f32ee540a 100644 --- a/docs/source/deployment/databricks/databricks_ide_development_workflow.md +++ b/docs/source/deployment/databricks/databricks_ide_development_workflow.md @@ -5,7 +5,7 @@ This guide demonstrates a workflow for developing Kedro projects on Databricks u By working in your local environment, you can take advantage of features within an IDE that are not available on Databricks notebooks: - Auto-completion and suggestions for code, improving your development speed and accuracy. -- Linters like Pylint or Flake8 can be integrated to catch potential issues in your code. +- Linters like [Ruff](https://docs.astral.sh/ruff) can be integrated to catch potential issues in your code. - Static type checkers like Mypy can check types in your code, helping to identify potential type-related issues early in the development process. To set up these features, look for instructions specific to your IDE (for instance, [VS Code](https://code.visualstudio.com/docs/python/linting)). diff --git a/docs/source/deployment/databricks/databricks_notebooks_development_workflow.md b/docs/source/deployment/databricks/databricks_notebooks_development_workflow.md index ef2081a28a..b3136d9a9d 100644 --- a/docs/source/deployment/databricks/databricks_notebooks_development_workflow.md +++ b/docs/source/deployment/databricks/databricks_notebooks_development_workflow.md @@ -19,7 +19,7 @@ This tutorial introduces a Kedro project development workflow using only the Dat - An active [Databricks deployment](https://docs.databricks.com/getting-started/index.html). - A [Databricks cluster](https://docs.databricks.com/clusters/configure.html) configured with a recent version (>= 11.3 is recommended) of the Databricks runtime. -- Python >= 3.7 installed. +- Python >= 3.8 installed. - Git installed. - A [GitHub](https://github.com/) account. 
- A Python environment management system installed, [venv](https://docs.python.org/3/library/venv.html), [virtualenv](https://virtualenv.pypa.io/en/latest/) or [Conda](https://docs.conda.io/en/latest/) are popular choices. diff --git a/docs/source/development/commands_reference.md b/docs/source/development/commands_reference.md index 5403f1b563..4dffb241b6 100644 --- a/docs/source/development/commands_reference.md +++ b/docs/source/development/commands_reference.md @@ -53,23 +53,17 @@ Here is a list of Kedro CLI commands, as a shortcut to the descriptions below. P * Global Kedro commands * [`kedro --help`](#get-help-on-kedro-commands) * [`kedro --version`](#confirm-the-kedro-version) - * [`kedro docs`](#open-the-kedro-documentation-in-your-browser) * [`kedro info`](#confirm-kedro-information) * [`kedro new`](#create-a-new-kedro-project) * Project-specific Kedro commands - * [`kedro activate-nbstripout`](#strip-output-cells)(deprecated from version 0.19.0) - * [`kedro build-docs`](#build-the-project-documentation) (deprecated from version 0.19.0) - * [`kedro build-reqs`](#build-the-projects-dependency-tree) (deprecated from version 0.19.0) * [`kedro catalog list`](#list-datasets-per-pipeline-per-type) * [`kedro catalog resolve`](#resolve-dataset-factories-in-the-catalog) * [`kedro catalog rank`](#rank-dataset-factories-in-the-catalog) * [`kedro catalog create`](#create-a-data-catalog-yaml-configuration-file) * [`kedro ipython`](#notebooks) - * [`kedro jupyter convert`](#copy-tagged-cells) (deprecated from version 0.19.0) * [`kedro jupyter lab`](#notebooks) * [`kedro jupyter notebook`](#notebooks) - * [`kedro lint`](#lint-your-project) (deprecated from version 0.19.0) * [`kedro micropkg package `](#package-a-micro-package) * [`kedro micropkg pull `](#pull-a-micro-package) * [`kedro package`](#deploy-the-project) @@ -78,7 +72,6 @@ Here is a list of Kedro CLI commands, as a shortcut to the descriptions below. P * [`kedro registry describe `](#describe-a-registered-pipeline) * [`kedro registry list`](#list-all-registered-pipelines-in-your-project) * [`kedro run`](#run-the-project) - * [`kedro test`](#test-your-project) (deprecated from version 0.19.0) ## Global Kedro commands @@ -133,12 +126,6 @@ kedro_viz: 4.4.0 (hooks:global,line_magic) kedro new ``` -### Open the Kedro documentation in your browser - -```bash -kedro docs -``` - ## Customise or Override Project-specific Kedro commands ```{note} @@ -239,7 +226,6 @@ def cli(): help=PARAMS_ARG_HELP, callback=_split_params, ) -# pylint: disable=too-many-arguments,unused-argument def run( tag, env, @@ -285,21 +271,6 @@ def run( ### Project setup -#### Build the project's dependency tree - -```{note} -_This command will be deprecated from Kedro version 0.19.0._ -``` -```bash -kedro build-reqs -``` - -This command runs [`pip-compile`](https://github.com/jazzband/pip-tools#example-usage-for-pip-compile) on the project's `src/requirements.txt` file and will create `src/requirements.lock` with the compiled requirements. - -`kedro build-reqs` has two optional arguments to specify which file to compile the requirements from and where to save the compiled requirements to. These arguments are `--input-file` and `--output-file` respectively. - -`kedro build-reqs` also accepts and passes through CLI options accepted by `pip-compile`. For example, `kedro build-reqs --generate-hashes` will call `pip-compile --output-file=src/requirements.lock --generate-hashes src/requirements.txt`. 
- #### Install all package dependencies The following runs [`pip`](https://github.com/pypa/pip) to install all package dependencies specified in `src/requirements.txt`: @@ -388,44 +359,6 @@ The above command will take the bundled `.tar.gz` file and do the following: ### Project quality -#### Build the project documentation - -```{note} -_This command will be deprecated from Kedro version 0.19.0._ -``` - -```bash -kedro build-docs -``` - -The `build-docs` command builds [project documentation](../tutorial/package_a_project.md#add-documentation-to-a-kedro-project) using the [Sphinx](https://www.sphinx-doc.org) framework. To further customise your documentation, please refer to `docs/source/conf.py` and the [Sphinx documentation](http://www.sphinx-doc.org/en/master/usage/configuration.html). - - -#### Lint your project - -```{note} -_This command will be deprecated from Kedro version 0.19.0._. We still recommend to (../development/linting.md) and you can find more help here -``` - -```bash -kedro lint -``` - -Your project is linted with [`black`](https://github.com/psf/black), [`flake8`](https://github.com/PyCQA/flake8) and [`isort`](https://github.com/PyCQA/isort). - - -#### Test your project - -```{note} -_This command will be deprecated from Kedro version 0.19.0._ -``` - -The following runs all `pytest` unit tests found in `src/tests`, including coverage (see the file `.coveragerc`): - -```bash -kedro test -``` - ### Project development #### Modular pipelines @@ -551,29 +484,3 @@ The [Kedro IPython extension](../notebooks_and_ipython/kedro_and_notebooks.md#a- * `session` (type `KedroSession`): [Kedro session](../kedro_project_setup/session.md) that orchestrates a pipeline run To reload these variables (e.g. if you updated `catalog.yml`) use the `%reload_kedro` line magic, which can also be used to see the error message if any of the variables above are undefined. - -##### Copy tagged cells - -```{note} -_This command will be deprecated from Kedro version 0.19.0._ -``` - -To copy the code from [cells tagged](https://jupyter-notebook.readthedocs.io/en/stable/changelog.html#cell-tags) with a `node` tag into Python files under `src//nodes/` in a Kedro project: - -```bash -kedro jupyter convert --all -``` - -##### Strip output cells - -```{note} -_This command will be deprecated from Kedro version 0.19.0._ -``` - -Output cells of Jupyter Notebook should not be tracked by git, especially if they contain sensitive information. To strip them out: - -```bash -kedro activate-nbstripout -``` - -This command adds a `git hook` which clears all notebook output cells before committing anything to `git`. It needs to run only once per local repository. diff --git a/docs/source/development/linting.md b/docs/source/development/linting.md index a8bdbc0c44..c4d2631848 100644 --- a/docs/source/development/linting.md +++ b/docs/source/development/linting.md @@ -9,8 +9,7 @@ Linting tools check your code for errors such as a missing bracket or line inden As a project grows and goes through various stages of development it becomes important to maintain code quality. Using a consistent format and linting your code ensures that it is consistent, readable, and easy to debug and maintain. ## Set up Python tools -There are a variety of Python tools available to use with your Kedro projects. This guide shows you how to use -[`black`](https://github.com/psf/black), [`ruff`](https://beta.ruff.rs). +There are a variety of Python tools available to use with your Kedro projects. 
This guide shows you how to use [`black`](https://github.com/psf/black) and [`ruff`](https://beta.ruff.rs). - **`black`** is a [PEP 8](https://peps.python.org/pep-0008/) compliant opinionated Python code formatter. `black` can check for styling inconsistencies and reformat your files in place. [You can read more in the `black` documentation](https://black.readthedocs.io/en/stable/). @@ -58,29 +57,13 @@ ignore = ["E501"] # Black take care off line-too-long It is a good practice to [split your line when it is too long](https://beta.ruff.rs/docs/rules/line-too-long/), so it can be read easily even in a small screen. `ruff` treats this slightly different from `black`, when using together we recommend to disable this rule, i.e. `E501` to avoid conflicts. ``` -#### Configure `flake8` - -Store your `flake8` configuration in a file named `.flake8` within your project root. The Kedro default project template use the [following configuration](https://github.com/kedro-org/kedro/blob/main/kedro/templates/project/%7B%7B%20cookiecutter.repo_name%20%7D%7D/.flake8): - -```text -[flake8] -max-line-length=88 -extend-ignore=E203 -``` - ### Run the tools Use the following commands to run lint checks: ```bash black --check -isort --profile black --check -``` -You can also have `black` and `isort` automatically format your code by omitting the `--check` flag. Since `isort` and -`black` both format your imports, adding `--profile black` to the `isort` run helps avoid potential conflicts. - -Use the following to invoke `flake8`: -```bash -flake8 +ruff check ``` +You can also have `black` automatically format your code by omitting the `--check` flag. ## Automated formatting and linting with `pre-commit` hooks @@ -101,7 +84,7 @@ pip install pre-commit ### Add `pre-commit` configuration file Create a file named `.pre-commit-config.yaml` in your Kedro project root directory. You can add entries for the hooks you want to run before each `commit`. -Below is a sample `YAML` file with entries for `black`,`flake8`, and `isort`: +Below is a sample `YAML` file with entries for `ruff` and `black`: ```yaml repos: - repo: https://github.com/astral-sh/ruff-pre-commit diff --git a/docs/source/development/set_up_pycharm.md b/docs/source/development/set_up_pycharm.md index 8bb1a75df1..b3e62d12aa 100644 --- a/docs/source/development/set_up_pycharm.md +++ b/docs/source/development/set_up_pycharm.md @@ -61,7 +61,7 @@ Specify the **Run / Debug Configuration** name in the **Name** field, and edit t ![](../meta/images/pycharm_edit_py_run_config.png) ```{note} -**Emulate terminal in output console** enables PyCharm to show [rich terminal output](../logging/logging.md#default-framework-side-logging-configuration). +**Emulate terminal in output console** enables PyCharm to show [rich terminal output](../logging/index.md). ``` To execute the Run configuration, select it from the **Run / Debug Configurations** dropdown in the toolbar (if that toolbar is not visible, you can enable it by going to **View > Toolbar**). Click the green triangle: @@ -72,7 +72,7 @@ You may also select **Run** from the toolbar and execute from there.
![](../meta/images/pycharm_conf_run_dropdown.png) -For other `kedro` commands, follow same steps but replace `run` in the `Parameters` field with the other commands that are to be used (e.g., `test`, `package`, `build-docs` etc.). +For other `kedro` commands, follow same steps but replace `run` in the `Parameters` field with the other commands that are to be used (e.g., `jupyter`, `package`, `registry` etc.). ## Debugging diff --git a/docs/source/experiment_tracking/index.md b/docs/source/experiment_tracking/index.md index 3004fe28e0..85c943e05c 100644 --- a/docs/source/experiment_tracking/index.md +++ b/docs/source/experiment_tracking/index.md @@ -135,7 +135,7 @@ export AWS_REGION="your_aws_region" ## Set up experiment tracking datasets -There are two types of tracking datasets: [`tracking.MetricsDataSet`](/kedro.extras.datasets.tracking.MetricsDataSet) and [`tracking.JSONDataSet`](/kedro.extras.datasets.tracking.JSONDataSet). The `tracking.MetricsDataSet` should be used for tracking numerical metrics, and the `tracking.JSONDataSet` can be used for tracking any other JSON-compatible data like boolean or text-based data. +There are two types of tracking datasets: [`tracking.MetricsDataSet`](/kedro_datasets.tracking.MetricsDataSet) and [`tracking.JSONDataSet`](/kedro_datasets.tracking.JSONDataSet). The `tracking.MetricsDataSet` should be used for tracking numerical metrics, and the `tracking.JSONDataSet` can be used for tracking any other JSON-compatible data like boolean or text-based data. Set up two datasets to log the columns used in the companies dataset (`companies_columns`) and experiment metrics for the data science pipeline (`metrics`) like the coefficient of determination (`r2 score`), max error (`me`) and mean absolute error (`mae`) by adding the following in the `conf/base/catalog.yml` file: diff --git a/docs/source/extend_kedro/plugins.md b/docs/source/extend_kedro/plugins.md index c93f180820..84345310d4 100644 --- a/docs/source/extend_kedro/plugins.md +++ b/docs/source/extend_kedro/plugins.md @@ -145,7 +145,7 @@ from kedro.framework.hooks import hook_impl class MyHooks: @hook_impl - def after_catalog_created(self, catalog): # pylint: disable=unused-argument + def after_catalog_created(self, catalog): logging.info("Reached after_catalog_created hook") diff --git a/docs/source/faq/faq.md b/docs/source/faq/faq.md index 69115f5e30..210320dbd3 100644 --- a/docs/source/faq/faq.md +++ b/docs/source/faq/faq.md @@ -8,8 +8,6 @@ This is a growing set of technical FAQs. The [product FAQs on the Kedro website] ## Working with Jupyter -* [How can I convert functions from Jupyter Notebooks into Kedro nodes](../notebooks_and_ipython/kedro_and_notebooks.md#convert-functions-from-jupyter-notebooks-into-kedro-nodes)? - * [How do I connect a Kedro project kernel to other Jupyter clients like JupyterLab](../notebooks_and_ipython/kedro_and_notebooks.md#ipython-jupyterlab-and-other-jupyter-clients)? ## Kedro project development diff --git a/docs/source/get_started/install.md b/docs/source/get_started/install.md index 8afea95a57..5f6db082f5 100644 --- a/docs/source/get_started/install.md +++ b/docs/source/get_started/install.md @@ -1,7 +1,7 @@ # Set up Kedro ## Installation prerequisites -* **Python**: Kedro supports macOS, Linux, and Windows and is built for Python 3.7+. You'll select a version of Python when you create a virtual environment for your Kedro project. +* **Python**: Kedro supports macOS, Linux, and Windows and is built for Python 3.8+. 
You'll select a version of Python when you create a virtual environment for your Kedro project. * **Virtual environment**: You should create a new virtual environment for *each* new Kedro project you work on to isolate its Python dependencies from those of other projects. @@ -23,7 +23,7 @@ The recommended approach. From your terminal: conda create --name kedro-environment python=3.10 -y ``` -The example below uses Python 3.10, and creates a virtual environment called `kedro-environment`. You can opt for a different version of Python (any version >= 3.7 and <3.11) for your project, and you can name it anything you choose. +The example below uses Python 3.10, and creates a virtual environment called `kedro-environment`. You can opt for a different version of Python (any version >= 3.8 and <3.12) for your project, and you can name it anything you choose. The `conda` virtual environment is not dependent on your current working directory and can be activated from any directory: @@ -184,7 +184,7 @@ pip install kedro ## Summary * Kedro can be used on Windows, macOS or Linux. -* Installation prerequisites include a virtual environment manager like `conda`, Python 3.7+, and `git`. +* Installation prerequisites include a virtual environment manager like `conda`, Python 3.8+, and `git`. * You should install Kedro using `pip install kedro`. If you encounter any problems as you set up Kedro, ask for help on Kedro's [Slack organisation](https://slack.kedro.org) or review the [searchable archive of Slack discussions](https://linen-slack.kedro.org/). diff --git a/docs/source/get_started/kedro_concepts.md b/docs/source/get_started/kedro_concepts.md index 67f9bf84a9..d87ed1aac0 100644 --- a/docs/source/get_started/kedro_concepts.md +++ b/docs/source/get_started/kedro_concepts.md @@ -72,7 +72,6 @@ project-dir # Parent directory of the template ├── notebooks # Project-related Jupyter notebooks (can be used for experimental code before moving the code to src) ├── pyproject.toml # Identifies the project root and contains configuration information ├── README.md # Project README -├── .flake8 # Configuration options for `flake8` (linting) └── src # Project source code ``` @@ -104,4 +103,4 @@ The `data` folder contains multiple subfolders to store project data. We recomme ### `src` -This subfolder contains the project's source code in one subfolder and another folder that you can use to add unit tests for your project. Projects are preconfigured to run tests using `pytest` when you call `kedro test` from the project's root directory. +This subfolder contains the project's source code. diff --git a/docs/source/index.rst b/docs/source/index.rst index 5850f15f76..ae1b641064 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -23,9 +23,9 @@ Welcome to Kedro's documentation! :target: https://opensource.org/license/apache2-0-php/ :alt: License is Apache 2.0 -.. image:: https://img.shields.io/badge/python-3.7%20%7C%203.8%20%7C%203.9%20%7C%203.10%20%7C%203.11-blue.svg +.. image:: https://img.shields.io/badge/3.8%20%7C%203.9%20%7C%203.10%20%7C%203.11-blue.svg :target: https://pypi.org/project/kedro/ - :alt: Python version 3.7, 3.8, 3.9, 3.10, 3.11 + :alt: Python version 3.8, 3.9, 3.10, 3.11 .. 
image:: https://badge.fury.io/py/kedro.svg :target: https://pypi.org/project/kedro/ diff --git a/docs/source/introduction/introduction.md b/docs/source/introduction/introduction.md index 45fea38bd2..6749a1dc49 100644 --- a/docs/source/introduction/introduction.md +++ b/docs/source/introduction/introduction.md @@ -11,6 +11,6 @@ Use the left-hand table of contents to explore the documentation available for m We have designed the preliminary documentation and the [spaceflights tutorial](../tutorial/spaceflights_tutorial.md) for anyone new to Kedro. The more knowledge of Python you have, the easier you will find the learning curve. ```{note} -There are many excellent online resources for learning Python; you should choose those that reference Python 3, as Kedro is built for Python 3.7+. There are curated lists of online resources, such as the [official Python programming language website](https://www.python.org/) and this list of [free programming books and tutorials](https://github.com/EbookFoundation/free-programming-books/blob/master/books/free-programming-books-langs.md#python). +There are many excellent online resources for learning Python; you should choose those that reference Python 3, as Kedro is built for Python 3.8+. There are curated lists of online resources, such as the [official Python programming language website](https://www.python.org/) and this list of [free programming books and tutorials](https://github.com/EbookFoundation/free-programming-books/blob/master/books/free-programming-books-langs.md#python). ``` diff --git a/docs/source/kedro.extras.datasets.rst b/docs/source/kedro.extras.datasets.rst deleted file mode 100644 index b26ee836e2..0000000000 --- a/docs/source/kedro.extras.datasets.rst +++ /dev/null @@ -1,51 +0,0 @@ -kedro.extras.datasets -===================== - -.. rubric:: Description - -.. automodule:: kedro.extras.datasets - -.. rubric:: Classes - -.. 
autosummary:: - :toctree: - :template: autosummary/class.rst - - kedro.extras.datasets.api.APIDataSet - kedro.extras.datasets.biosequence.BioSequenceDataSet - kedro.extras.datasets.dask.ParquetDataSet - kedro.extras.datasets.email.EmailMessageDataSet - kedro.extras.datasets.geopandas.GeoJSONDataSet - kedro.extras.datasets.holoviews.HoloviewsWriter - kedro.extras.datasets.json.JSONDataSet - kedro.extras.datasets.matplotlib.MatplotlibWriter - kedro.extras.datasets.networkx.GMLDataSet - kedro.extras.datasets.networkx.GraphMLDataSet - kedro.extras.datasets.networkx.JSONDataSet - kedro.extras.datasets.pandas.CSVDataSet - kedro.extras.datasets.pandas.ExcelDataSet - kedro.extras.datasets.pandas.FeatherDataSet - kedro.extras.datasets.pandas.GBQQueryDataSet - kedro.extras.datasets.pandas.GBQTableDataSet - kedro.extras.datasets.pandas.GenericDataSet - kedro.extras.datasets.pandas.HDFDataSet - kedro.extras.datasets.pandas.JSONDataSet - kedro.extras.datasets.pandas.ParquetDataSet - kedro.extras.datasets.pandas.SQLQueryDataSet - kedro.extras.datasets.pandas.SQLTableDataSet - kedro.extras.datasets.pandas.XMLDataSet - kedro.extras.datasets.pickle.PickleDataSet - kedro.extras.datasets.pillow.ImageDataSet - kedro.extras.datasets.plotly.JSONDataSet - kedro.extras.datasets.plotly.PlotlyDataSet - kedro.extras.datasets.redis.PickleDataSet - kedro.extras.datasets.spark.DeltaTableDataSet - kedro.extras.datasets.spark.SparkDataSet - kedro.extras.datasets.spark.SparkHiveDataSet - kedro.extras.datasets.spark.SparkJDBCDataSet - kedro.extras.datasets.svmlight.SVMLightDataSet - kedro.extras.datasets.tensorflow.TensorFlowModelDataset - kedro.extras.datasets.text.TextDataSet - kedro.extras.datasets.tracking.JSONDataSet - kedro.extras.datasets.tracking.MetricsDataSet - kedro.extras.datasets.yaml.YAMLDataSet diff --git a/docs/source/kedro.extras.logging.color_logger.ColorHandler.rst b/docs/source/kedro.extras.logging.color_logger.ColorHandler.rst deleted file mode 100644 index 8a762bb2c7..0000000000 --- a/docs/source/kedro.extras.logging.color_logger.ColorHandler.rst +++ /dev/null @@ -1,6 +0,0 @@ -kedro.extras.logging.color\_logger.ColorHandler -=============================================== - -.. currentmodule:: kedro.extras.logging.color_logger - -.. autoclass:: ColorHandler diff --git a/docs/source/kedro.extras.rst b/docs/source/kedro.extras.rst deleted file mode 100644 index 0980b1a41b..0000000000 --- a/docs/source/kedro.extras.rst +++ /dev/null @@ -1,20 +0,0 @@ -kedro.extras -============ - -.. rubric:: Description - -.. automodule:: kedro.extras - -.. rubric:: Modules - -.. autosummary:: - :toctree: - :recursive: - - kedro.extras.extensions - kedro.extras.logging - -.. 
toctree:: - :hidden: - kedro.extras.datasets diff --git a/docs/source/kedro_project_setup/starters.md b/docs/source/kedro_project_setup/starters.md index 305fe1de00..9454f6505d 100644 --- a/docs/source/kedro_project_setup/starters.md +++ b/docs/source/kedro_project_setup/starters.md @@ -155,7 +155,6 @@ Here is the layout of the project as a Cookiecutter template: ├── docs # Project documentation ├── notebooks # Project related Jupyter notebooks (can be used for experimental code before moving the code to src) ├── README.md # Project README -├── .flake8 # Configuration options for `flake8` (linting) └── src # Project source code └── {{ cookiecutter.python_package }} ├── __init.py__ diff --git a/docs/source/logging/index.md b/docs/source/logging/index.md index a3750bb188..23fb20d26d 100644 --- a/docs/source/logging/index.md +++ b/docs/source/logging/index.md @@ -1,19 +1,185 @@ # Logging -Kedro uses [Python's `logging` library](https://docs.python.org/3/library/logging.html). Configuration is provided as a dictionary according to the [Python logging configuration schema](https://docs.python.org/3/library/logging.config.html#logging-config-dictschema) in two places: -1. [Default configuration built into the Kedro framework](https://github.com/kedro-org/kedro/blob/main/kedro/framework/project/default_logging.yml). This cannot be altered. -2. Your project-side logging configuration. Every project generated using Kedro's CLI `kedro new` command includes a file `conf/base/logging.yml`. You can alter this configuration and provide different configurations for different run environment according to the [standard Kedro mechanism for handling configuration](../configuration/configuration_basics.md). +Kedro uses [Python's `logging` library](https://docs.python.org/3/library/logging.html). Configuration is provided as a dictionary according to the [Python logging configuration schema](https://docs.python.org/3/library/logging.config.html#logging-config-dictschema) in Kedro's default logging configuration, as described below. + +By default, Python only shows logging messages at level `WARNING` and above. Kedro's logging configuration specifies that `INFO` level messages from Kedro should also be emitted. This makes it easier to track the progress of your pipeline when you perform a `kedro run`. + +## Default logging configuration +Kedro's [default logging configuration](https://github.com/kedro-org/kedro/blob/main/kedro/framework/project/default_logging.yml) defines a handler called `rich` that uses the [Rich logging handler](https://rich.readthedocs.io) to format messages. We also use the Rich traceback handler to render exceptions. + +## How to perform logging in your Kedro project +To add logging to your own code (e.g. in a node): + +```python +import logging + +logger = logging.getLogger(__name__) +logger.warning("Issue warning") +logger.info("Send information") +logger.debug("Useful information for debugging") +``` + +You can use Rich's [console markup](https://rich.readthedocs.io/en/stable/markup.html) in your logging calls: + +```python +logger.error("[bold red blink]Important error message![/]", extra={"markup": True}) +``` + +## How to customise Kedro logging + +To customise logging in your Kedro project, you need to specify the path to a project-specific logging configuration file. Change the environment variable `KEDRO_LOGGING_CONFIG` to override the default logging configuration.
Point the variable instead to your project-specific configuration, which we recommend you store inside the project's `conf` folder, and name `logging.yml`. + +For example, you can set `KEDRO_LOGGING_CONFIG` by typing the following into your terminal: + +```bash +export KEDRO_LOGGING_CONFIG=/conf/logging.yml +``` + +After setting the environment variable, any subsequent Kedro commands use the logging configuration file at the specified path. ```{note} -Providing project-side logging configuration is entirely optional. You can delete the `conf/base/logging.yml` file and Kedro will run using the framework's built in configuration. +If the `KEDRO_LOGGING_CONFIG` environment variable is not set, Kedro will use the [default logging configuration](https://github.com/kedro-org/kedro/blob/main/kedro/framework/project/default_logging.yml). ``` -Framework-side and project-side logging configuration are loaded through subsequent calls to [`logging.config.dictConfig`](https://docs.python.org/3/library/logging.config.html#logging.config.dictConfig). This means that, when it is provided, the project-side logging configuration typically _fully overwrites_ the framework-side logging configuration. [Incremental configuration](https://docs.python.org/3/library/logging.config.html#incremental-configuration) is also possible if the `incremental` key is explicitly set to `True` in your project-side logging configuration. +### How to show DEBUG level messages +To see `DEBUG` level messages, change the level of logging in your project-specific logging configuration file (`logging.yml`). We provide a `logging.yml` template: + +
+Click to expand the logging.yml template + +```yaml +version: 1 -```{toctree} -:hidden: +disable_existing_loggers: False -logging +formatters: + simple: + format: "%(asctime)s - %(name)s - %(levelname)s - %(message)s" + +handlers: + console: + class: logging.StreamHandler + level: INFO + formatter: simple + stream: ext://sys.stdout + + info_file_handler: + class: logging.handlers.RotatingFileHandler + level: INFO + formatter: simple + filename: info.log + maxBytes: 10485760 # 10MB + backupCount: 20 + encoding: utf8 + delay: True + + rich: + class: kedro.logging.RichHandler + rich_tracebacks: True + # Advance options for customisation. + # See https://docs.kedro.org/en/stable/logging/logging.html#project-side-logging-configuration + # tracebacks_show_locals: False + +loggers: + kedro: + level: INFO + + your_python_package: + level: INFO + +root: + handlers: [rich] +``` + +
+ +You need to change the line: +```diff +loggers: + kedro: + level: INFO + + your_python_package: +- level: INFO ++ level: DEBUG +``` + +```{note} +The name of a logger corresponds to a key in the `loggers` section of the logging configuration file (e.g. `kedro`). See [Python's logging documentation](https://docs.python.org/3/library/logging.html#logger-objects) for more information. ``` + +By changing the level value to `DEBUG` for the desired logger (e.g. `your_python_package`), you will start seeing `DEBUG` level messages in the log output. + +## Advanced logging + +In addition to the `rich` handler defined in Kedro's framework, we provide two additional handlers in the template. + +* `console`: show logs on standard output (typically your terminal screen) without any rich formatting +* `info_file_handler`: write logs of level `INFO` and above to `info.log` + +The following section illustrates some common examples of how to change your project's logging configuration. + +## How to customise the `rich` handler + +Kedro's `kedro.logging.RichHandler` is a subclass of [`rich.logging.RichHandler`](https://rich.readthedocs.io/en/stable/reference/logging.html#rich.logging.RichHandler) and supports the same set of arguments. By default, `rich_tracebacks` is set to `True` to use `rich` to render exceptions. However, you can disable it by setting `rich_tracebacks: False`. + +```{note} +If you want to disable `rich`'s tracebacks, you must set `KEDRO_LOGGING_CONFIG` to point to your local config i.e. `conf/logging.yml`. +``` + +When `rich_tracebacks` is set to `True`, the configuration is propagated to [`rich.traceback.install`](https://rich.readthedocs.io/en/stable/reference/traceback.html#rich.traceback.install). If an argument is compatible with `rich.traceback.install`, it will be passed to the traceback's settings. + +For instance, you can enable the display of local variables inside `logging.yml` to aid with debugging. + +```diff + rich: + class: kedro.logging.RichHandler + rich_tracebacks: True ++ tracebacks_show_locals: True +``` + +A comprehensive list of available options can be found in the [RichHandler documentation](https://rich.readthedocs.io/en/stable/reference/logging.html#rich.logging.RichHandler). + +## How to enable file-based logging + +File-based logging in Python projects aids troubleshooting and debugging. It offers better visibility into the application's behaviour and is easy to search. However, it does not work well with read-only systems such as [Databricks Repos](https://docs.databricks.com/repos/index.html). + +To enable file-based logging, add `info_file_handler` to your `root` logger in your `conf/logging.yml` as follows: + +```diff + root: +- handlers: [rich] ++ handlers: [rich, info_file_handler] +``` + +By default, it only tracks `INFO` level messages, but it can be configured to capture any level of logs. + +## How to use plain console logging + +To use plain rather than rich logging, swap the `rich` handler for the `console` one as follows: + +```diff + root: +- handlers: [rich] ++ handlers: [console] +``` + +## How to enable rich logging in a dumb terminal + +Rich [detects whether your terminal is capable](https://rich.readthedocs.io/en/stable/console.html#terminal-detection) of displaying richly formatted messages. If your terminal is "dumb" then formatting is automatically stripped out so that the logs are just plain text. This is likely to happen if you perform `kedro run` on CI (e.g. GitHub Actions or CircleCI).
+ +If you find that the default wrapping of the log messages is too narrow but do not wish to switch to using the `console` logger on CI then the simplest way to control the log message wrapping is through altering the `COLUMNS` and `LINES` environment variables. For example: + +```bash +export COLUMNS=120 LINES=25 +``` + +```{note} +You must provide a value for both `COLUMNS` and `LINES` even if you only wish to change the width of the log message. Rich's default values for these variables are `COLUMNS=80` and `LINE=25`. +``` + +## How to enable rich logging in Jupyter + +Rich also formats the logs in JupyterLab and Jupyter Notebook. The size of the output console does not adapt to your window but can be controlled through the `JUPYTER_COLUMNS` and `JUPYTER_LINES` environment variables. The default values (115 and 100 respectively) should be suitable for most users, but if you require a different output console size then you should alter the values of `JUPYTER_COLUMNS` and `JUPYTER_LINES`. diff --git a/docs/source/logging/logging.md b/docs/source/logging/logging.md deleted file mode 100644 index 93f3332660..0000000000 --- a/docs/source/logging/logging.md +++ /dev/null @@ -1,115 +0,0 @@ - -# Default framework-side logging configuration - -Kedro's [default logging configuration](https://github.com/kedro-org/kedro/blob/main/kedro/framework/project/default_logging.yml) defines a handler called `rich` that uses the [Rich logging handler](https://rich.readthedocs.io/en/stable/logging.html) to format messages. We also use the [Rich traceback handler](https://rich.readthedocs.io/en/stable/traceback.html) to render exceptions. - -By default, Python only shows logging messages at level `WARNING` and above. Kedro's logging configuration specifies that `INFO` level messages from Kedro should also be emitted. This makes it easier to track the progress of your pipeline when you perform a `kedro run`. - -## Project-side logging configuration - -In addition to the `rich` handler defined in Kedro's framework, the [project-side `conf/base/logging.yml`](https://github.com/kedro-org/kedro/blob/main/kedro/templates/project/%7B%7B%20cookiecutter.repo_name%20%7D%7D/conf/base/logging.yml) defines two further logging handlers: -* `console`: show logs on standard output (typically your terminal screen) without any rich formatting -* `info_file_handler`: write logs of level `INFO` and above to `info.log` - -The logging handlers that are actually used by default are `rich` and `info_file_handler`. - -The project-side logging configuration also ensures that [logs emitted from your project's logger](#perform-logging-in-your-project) should be shown if they are `INFO` level or above (as opposed to the Python default of `WARNING`). - -We now give some common examples of how you might like to change your project's logging configuration. - -### Using `KEDRO_LOGGING_CONFIG` environment variable - -`KEDRO_LOGGING_CONFIG` is an optional environment variable that you can use to specify the path of your logging configuration file, overriding the default Kedro's `default_logging.yml`. - -To use this environment variable, set it to the path of your desired logging configuration file before running any Kedro commands. 
For example, if you have a logging configuration file located at `/path/to/logging.yml`, you can set `KEDRO_LOGGING_CONFIG` as follows: - -```bash -export KEDRO_LOGGING_CONFIG=/path/to/logging.yml -``` - -After setting the environment variable, any subsequent Kedro commands will use the logging configuration file at the specified path. - -```{note} -If the `KEDRO_LOGGING_CONFIG` environment variable is not set, Kedro will default to using the logging configuration file at the project's default location of Kedro's `default_logging.yml`. -``` -### Disable file-based logging - -You might sometimes need to disable file-based logging, e.g. if you are running Kedro on a read-only file system such as [Databricks Repos](https://docs.databricks.com/repos/index.html). The simplest way to do this is to delete your `conf/base/logging.yml` file. With no project-side logging configuration specified, Kedro uses the default framework-side logging configuration, which does not include any file-based handlers. - -Alternatively, if you would like to keep other configuration in `conf/base/logging.yml` and just disable file-based logging, then you can remove the file-based handlers from the `root` logger as follows: -```diff - root: -- handlers: [console, info_file_handler] -+ handlers: [console] -``` - -### Customise the `rich` Handler - -Kedro's `kedro.extras.logging.RichHandler` is a subclass of [`rich.logging.RichHandler`](https://rich.readthedocs.io/en/stable/reference/logging.html#rich.logging.RichHandler) and supports the same set of arguments. By default, `rich_tracebacks` is set to `True` to use `rich` to render exceptions. However, you can disable it by setting `rich_tracebacks: False`. - -```{note} -If you want to disable `rich`'s tracebacks, you must set `KEDRO_LOGGING_CONFIG` to point to your local config i.e. `conf/base/logging.yml`. -``` - -When `rich_tracebacks` is set to `True`, the configuration is propagated to [`rich.traceback.install`](https://rich.readthedocs.io/en/stable/reference/traceback.html#rich.traceback.install). If an argument is compatible with `rich.traceback.install`, it will be passed to the traceback's settings. - -For instance, you can enable the display of local variables inside `logging.yml` to aid with debugging. - -```yaml -rich: - class: kedro.extras.logging.RichHandler - rich_tracebacks: True - tracebacks_show_locals: True -``` - -A comprehensive list of available options can be found in the [RichHandler documentation](https://rich.readthedocs.io/en/stable/reference/logging.html#rich.logging.RichHandler). - - -### Use plain console logging - -To use plain rather than rich logging, swap the `rich` handler for the `console` one as follows: - -```diff - root: -- handlers: [rich, info_file_handler] -+ handlers: [console, info_file_handler] -``` - -### Rich logging in a dumb terminal - -Rich [detects whether your terminal is capable](https://rich.readthedocs.io/en/stable/console.html#terminal-detection) of displaying richly formatted messages. If your terminal is "dumb" then formatting is automatically stripped out so that the logs are just plain text. This is likely to happen if you perform `kedro run` on CI (e.g. GitHub Actions or CircleCI). - -If you find that the default wrapping of the log messages is too narrow but do not wish to switch to using the `console` logger on CI then the simplest way to control the log message wrapping is through altering the `COLUMNS` and `LINES` environment variables. 
For example: - -```bash -export COLUMNS=120 LINES=25 -``` - -```{note} -You must provide a value for both `COLUMNS` and `LINES` even if you only wish to change the width of the log message. Rich's default values for these variables are `COLUMNS=80` and `LINE=25`. -``` - -### Rich logging in Jupyter - -Rich also formats the logs in JupyterLab and Jupyter Notebook. The size of the output console does not adapt to your window but can be controlled through the `JUPYTER_COLUMNS` and `JUPYTER_LINES` environment variables. The default values (115 and 100 respectively) should be suitable for most users, but if you require a different output console size then you should alter the values of `JUPYTER_COLUMNS` and `JUPYTER_LINES`. - -## Perform logging in your project - -To perform logging in your own code (e.g. in a node), you are advised to do as follows: - -```python -import logging - -log = logging.getLogger(__name__) -log.warning("Issue warning") -log.info("Send information") -``` - -```{note} -The name of a logger corresponds to a key in the `loggers` section in `logging.yml` (e.g. `kedro`). See [Python's logging documentation](https://docs.python.org/3/library/logging.html#logger-objects) for more information. -``` - -You can take advantage of rich's [console markup](https://rich.readthedocs.io/en/stable/markup.html) when enabled in your logging calls: -```python -log.error("[bold red blink]Important error message![/]", extra={"markup": True}) -``` diff --git a/docs/source/nodes_and_pipelines/nodes.md b/docs/source/nodes_and_pipelines/nodes.md index 80902a2cb6..11bc17e960 100644 --- a/docs/source/nodes_and_pipelines/nodes.md +++ b/docs/source/nodes_and_pipelines/nodes.md @@ -295,7 +295,7 @@ import pandas as pd from kedro.io.core import ( get_filepath_str, ) -from kedro.extras.datasets.pandas import CSVDataSet +from kedro_datasets.pandas import CSVDataSet class ChunkWiseCSVDataset(CSVDataSet): diff --git a/docs/source/notebooks_and_ipython/kedro_and_notebooks.md b/docs/source/notebooks_and_ipython/kedro_and_notebooks.md index 0cd509b32c..133e4a13b9 100644 --- a/docs/source/notebooks_and_ipython/kedro_and_notebooks.md +++ b/docs/source/notebooks_and_ipython/kedro_and_notebooks.md @@ -192,32 +192,6 @@ For more details, run `%reload_kedro?`. If you have [Kedro-Viz](https://github.com/kedro-org/kedro-viz) installed for the project you can display an interactive visualisation of your pipeline directly in your Notebook using the [line magic](https://ipython.readthedocs.io/en/stable/interactive/magics.html) `%run_viz`. - -## Convert functions from Jupyter Notebooks into Kedro nodes - -If you are writing experimental code in your Notebook and later want to convert functions you've written to Kedro nodes, you can do this using tags. - -Say you have the following code in your Notebook: - -```ipython -def some_action(): - print("This function came from `notebooks/my_notebook.ipynb`") -``` - -1. Enable tags toolbar: `View` menu -> `Cell Toolbar` -> `Tags` -![Enable the tags toolbar graphic](../meta/images/jupyter_notebook_workflow_activating_tags.png) - -2. Add the `node` tag to the cell containing your function -![Add the node tag graphic](../meta/images/jupyter_notebook_workflow_tagging_nodes.png) - -```{note} -The Notebook can contain multiple functions tagged as `node`, each of them will be exported into the resulting Python file -``` - -3. Save your Jupyter Notebook to `notebooks/my_notebook.ipynb` -4. 
From your terminal, run `kedro jupyter convert notebooks/my_notebook.ipynb` from the Kedro project directory. The output is a Python file `src//nodes/my_notebook.py` containing the `some_action` function definition -5. The `some_action` function can now be used in your Kedro pipelines - ## Useful to know... Each Kedro project has its own Jupyter kernel so you can switch between multiple Kedro projects from a single Jupyter instance simply by selecting the appropriate kernel. diff --git a/docs/source/tutorial/package_a_project.md b/docs/source/tutorial/package_a_project.md index 010aed2e6c..b0e22fead1 100644 --- a/docs/source/tutorial/package_a_project.md +++ b/docs/source/tutorial/package_a_project.md @@ -15,13 +15,6 @@ pip install sphinx ``` ### Set up the Sphinx project files - -```{warning} -Currently, Kedro projects are created with a `docs/source` subdirectory, which gets pre-populated with two Sphinx configuration files (`conf.py`, and `index.rst`), needed by the `kedro build-docs` command. This command is deprecated; it will be removed in Kedro version 0.19, along with those dummy files. - -Before proceeding with these instructions, back up the contents of `docs/source/index.rst` and remove both `docs/source/conf.py` and `docs/source/index.rst`. -``` - First, run the following command: ```bash diff --git a/docs/source/tutorial/tutorial_template.md b/docs/source/tutorial/tutorial_template.md index 99b75cb031..71b8564c55 100644 --- a/docs/source/tutorial/tutorial_template.md +++ b/docs/source/tutorial/tutorial_template.md @@ -34,14 +34,15 @@ The spaceflights project dependencies are stored in `src/requirements.txt`(you m # code quality packages black==22.0 flake8>=3.7.9, <5.0 -ipython>=7.31.1, <8.0 +ipython>=7.31.1, <8.0; python_version < '3.8' +ipython~=8.10; python_version >= '3.8' isort~=5.0 -nbstripout~=0.4 # notebook tooling jupyter~=1.0 jupyterlab~=3.0 jupyterlab_server>=2.11.1, <2.16.0 +nbstripout~=0.4 # Pytest + useful extensions pytest-cov~=3.0 @@ -49,8 +50,8 @@ pytest-mock>=1.7.1, <2.0 pytest~=7.2 # Kedro dependencies and datasets to work with different data formats (including CSV, Excel, and Parquet) -kedro~=0.18.10 -kedro-datasets[pandas.CSVDataSet, pandas.ExcelDataSet, pandas.ParquetDataSet]~=1.1 +kedro~=0.18.13 +kedro-datasets[pandas.CSVDataSet, pandas.ExcelDataSet, pandas.ParquetDataSet]~=1.0 kedro-telemetry~=0.2.0 kedro-viz~=6.0 # Visualise pipelines @@ -68,7 +69,7 @@ pip install -r src/requirements.txt ## Optional: logging and configuration -You might want to [set up logging](../logging/logging.md) at this stage of the workflow, but we do not use it in this tutorial. +You might want to [set up logging](../logging/index.md) at this stage of the workflow, but we do not use it in this tutorial. You may also want to store credentials such as usernames and passwords if they are needed for specific data sources used by the project. 
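The credentials mentioned above are normally kept in `conf/local/credentials.yml` and read back through the project's config loader, so they never need to be hard-coded in the catalog. A minimal sketch, assuming a standard `conf/` layout; the `dev_s3` key is a hypothetical example, not a file this tutorial creates:

```python
from kedro.config import ConfigLoader

# Read the "local" environment of a standard conf/ folder. The conf_source
# path and the "dev_s3" entry are illustrative assumptions.
conf_loader = ConfigLoader(conf_source="conf", env="local")
credentials = conf_loader["credentials"]  # matched by the credentials* patterns
dev_s3_credentials = credentials["dev_s3"]
```

Catalog entries can then reference the same key through their `credentials` field, which keeps secrets out of `catalog.yml`.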
diff --git a/features/activate_nbstripout.feature b/features/activate_nbstripout.feature deleted file mode 100644 index fa221417d4..0000000000 --- a/features/activate_nbstripout.feature +++ /dev/null @@ -1,17 +0,0 @@ -Feature: Activate_nbstripout target in new project - - Scenario: Check nbstripout git post commit hook functionality - Given I have prepared a config file - And I have run a non-interactive kedro new with starter "default" - And I have added a test jupyter notebook - And I have initialized a git repository - And I have added the project directory to staging - And I have committed changes to git - And I have executed the kedro command "activate-nbstripout" - When I execute the test jupyter notebook and save changes - And I add the project directory to staging - And I commit changes to git - And I remove the notebooks directory - And I perform a hard git reset to restore the project to last commit - Then there should be an additional cell in the jupyter notebook - And the output should be empty in all the cells in the jupyter notebook diff --git a/features/build_docs.feature b/features/build_docs.feature deleted file mode 100644 index c9f9307ef1..0000000000 --- a/features/build_docs.feature +++ /dev/null @@ -1,11 +0,0 @@ -Feature: build-docs target in new project - - @fresh_venv - Scenario: Execute build-docs target - Given I have prepared a config file - And I have run a non-interactive kedro new with starter "default" - And I have updated kedro requirements - And I have installed the project dependencies - When I execute the kedro command "build-docs" - Then I should get a successful exit code - And docs should be generated diff --git a/features/build_reqs.feature b/features/build_reqs.feature deleted file mode 100644 index 05bf551961..0000000000 --- a/features/build_reqs.feature +++ /dev/null @@ -1,13 +0,0 @@ -Feature: build-reqs target in new project - - @fresh_venv - Scenario: Execute build-reqs target - Given I have prepared a config file - And I have run a non-interactive kedro new with starter "default" - And I have updated kedro requirements - And I have executed the kedro command "build-reqs --resolver=backtracking" - When I add scrapy>=1.7.3 to the requirements - And I execute the kedro command "build-reqs --resolver=backtracking" - Then I should get a successful exit code - And requirements should be generated - And scrapy should be in the requirements diff --git a/features/environment.py b/features/environment.py index ea51cc9a56..4e801fecaf 100644 --- a/features/environment.py +++ b/features/environment.py @@ -103,8 +103,6 @@ def _setup_minimal_env(context): "pip", "install", "-U", - # pip==23.2 breaks pip-tools<7.0, and pip-tools>=7.0 does not support Python 3.7 - "pip>=21.2,<23.2; python_version < '3.8'", "pip>=21.2; python_version >= '3.8'", ], env=context.env, @@ -122,6 +120,6 @@ def _install_project_requirements(context): .splitlines() ) install_reqs = [req for req in install_reqs if "{" not in req and "#" not in req] - install_reqs.append(".[pandas.CSVDataSet]") + install_reqs.append("kedro-datasets[pandas.CSVDataSet]") call([context.pip, "install", *install_reqs], env=context.env) return context diff --git a/features/jupyter.feature b/features/jupyter.feature index 65b5173442..2769ff8b20 100644 --- a/features/jupyter.feature +++ b/features/jupyter.feature @@ -17,11 +17,3 @@ Feature: Jupyter targets in new project When I execute the kedro jupyter command "lab --no-browser" Then I wait for the jupyter webserver to run for up to "120" seconds Then Jupyter Lab 
should run on port 8888 - - Scenario: Execute node convert into Python files - Given I have added a test jupyter notebook - When I execute the test jupyter notebook and save changes - And I execute the kedro jupyter command "convert --all" - And Wait until the process is finished for up to "120" seconds - Then I should get a successful exit code - And Code cell with node tag should be converted into kedro node diff --git a/features/load_context.feature b/features/load_context.feature index 7930cf2b8b..e5bcb2dcd2 100644 --- a/features/load_context.feature +++ b/features/load_context.feature @@ -1,7 +1,7 @@ Feature: Custom Kedro project Background: Given I have prepared a config file - And I have run a non-interactive kedro new with starter "default" + And I have run a non-interactive kedro new with starter "default" Scenario: Update the source directory to be nested When I move the package to "src/nested" diff --git a/features/package.feature b/features/package.feature index 663ea87c49..e840a1b9d7 100644 --- a/features/package.feature +++ b/features/package.feature @@ -12,14 +12,3 @@ Feature: Package target in new project When I install the project's python package And I execute the installed project package Then I should get a successful exit code - - @fresh_venv - Scenario: Install package after running kedro build-reqs - Given I have updated kedro requirements - When I execute the kedro command "build-reqs --resolver=backtracking" - Then I should get a successful exit code - When I execute the kedro command "package" - Then I should get a successful exit code - When I install the project's python package - And I execute the installed project package - Then I should get a successful exit code diff --git a/features/steps/cli_steps.py b/features/steps/cli_steps.py index f5931828fa..a4b2096fdd 100644 --- a/features/steps/cli_steps.py +++ b/features/steps/cli_steps.py @@ -454,21 +454,17 @@ def is_created(name): @then("the logs should show that {number} nodes were run") def check_one_node_run(context, number): expected_log_line = f"Completed {number} out of {number} tasks" - info_log = context.root_project_dir / "logs" / "info.log" assert expected_log_line in context.result.stdout - assert expected_log_line in info_log.read_text() @then('the logs should show that "{node}" was run') def check_correct_nodes_run(context, node): expected_log_line = f"Running node: {node}" - info_log = context.root_project_dir / "logs" / "info.log" stdout = context.result.stdout assert expected_log_line in stdout, ( "Expected the following message segment to be printed on stdout: " f"{expected_log_line},\nbut got {stdout}" ) - assert expected_log_line in info_log.read_text(), info_log.read_text() @then("I should get a successful exit code") diff --git a/features/steps/sh_run.py b/features/steps/sh_run.py index 476cee7106..e6f10cbf81 100644 --- a/features/steps/sh_run.py +++ b/features/steps/sh_run.py @@ -36,7 +36,6 @@ def run( """ if isinstance(cmd, str) and split: cmd = shlex.split(cmd) - # pylint: disable=subprocess-run-check result = subprocess.run(cmd, input="", capture_output=True, **kwargs) result.stdout = result.stdout.decode("utf-8") result.stderr = result.stderr.decode("utf-8") diff --git a/features/steps/test_starter/{{ cookiecutter.repo_name }}/.flake8 b/features/steps/test_starter/{{ cookiecutter.repo_name }}/.flake8 deleted file mode 100644 index 8dd399ab55..0000000000 --- a/features/steps/test_starter/{{ cookiecutter.repo_name }}/.flake8 +++ /dev/null @@ -1,3 +0,0 @@ -[flake8] -max-line-length = 88 
-extend-ignore = E203 diff --git a/features/steps/test_starter/{{ cookiecutter.repo_name }}/README.md b/features/steps/test_starter/{{ cookiecutter.repo_name }}/README.md index fd206a315d..72d74e4060 100644 --- a/features/steps/test_starter/{{ cookiecutter.repo_name }}/README.md +++ b/features/steps/test_starter/{{ cookiecutter.repo_name }}/README.md @@ -4,7 +4,7 @@ This is your new Kedro project, which was generated using `Kedro {{ cookiecutter.kedro_version }}`. -Take a look at the [Kedro documentation](https://kedro.readthedocs.io) to get started. +Take a look at the [Kedro documentation](https://docs.kedro.org) to get started. ## Rules and guidelines @@ -38,7 +38,7 @@ kedro run Have a look at the file `src/tests/test_run.py` for instructions on how to write your tests. You can run your tests as follows: ``` -kedro test +pytest ``` To configure the coverage threshold, look at the `.coveragerc` file. @@ -46,17 +46,9 @@ To configure the coverage threshold, look at the `.coveragerc` file. ## Project dependencies -To generate or update the dependency requirements for your project: +To see and update the dependency requirements for your project use `src/requirements.txt`. You can install the project requirements with `pip install -r src/requirements.txt`. -``` -kedro build-reqs -``` - -This will `pip-compile` the contents of `src/requirements.txt` into a new file `src/requirements.lock`. You can see the output of the resolution by opening `src/requirements.lock`. - -After this, if you'd like to update your project requirements, please update `src/requirements.txt` and re-run `kedro build-reqs`. - -[Further information about project dependencies](https://kedro.readthedocs.io/en/stable/kedro_project_setup/dependencies.html#project-specific-dependencies) +[Further information about project dependencies](https://docs.kedro.org/en/stable/kedro_project_setup/dependencies.html#project-specific-dependencies) ## How to work with Kedro and notebooks @@ -95,27 +87,11 @@ And if you want to run an IPython session: kedro ipython ``` -### How to convert notebook cells to nodes in a Kedro project -You can move notebook code over into a Kedro project structure using a mixture of [cell tagging](https://jupyter-notebook.readthedocs.io/en/stable/changelog.html#id35) and Kedro CLI commands. - -By adding the `node` tag to a cell and running the command below, the cell's source code will be copied over to a Python file within `src//nodes/`: - -``` -kedro jupyter convert -``` -> *Note:* The name of the Python file matches the name of the original notebook. - -Alternatively, you may want to transform all your notebooks in one go. Run the following command to convert all notebook files found in the project root directory and under any of its sub-folders: - -``` -kedro jupyter convert --all -``` - ### How to ignore notebook output cells in `git` -To automatically strip out all output cell contents before committing to `git`, you can run `kedro activate-nbstripout`. This will add a hook in `.git/config` which will run `nbstripout` before anything is committed to `git`. +To automatically strip out all output cell contents before committing to `git`, you can use tools like [`nbstripout`](https://github.com/kynan/nbstripout). For example, you can add a hook in `.git/config` with `nbstripout --install`. This will run `nbstripout` before anything is committed to `git`. > *Note:* Your output cells will be retained locally. 
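If you would rather not add another dependency, the same clean-up can be scripted with `nbformat`, which is typically already present alongside Jupyter. This is a rough sketch of an alternative to `nbstripout`, and the notebook path is only an example:

```python
import nbformat

# Strip outputs and execution counts from one notebook in place, roughly
# mirroring what nbstripout does before a commit.
path = "notebooks/my_notebook.ipynb"  # example path
nb = nbformat.read(path, as_version=4)
for cell in nb.cells:
    if cell.cell_type == "code":
        cell.outputs = []
        cell.execution_count = None
nbformat.write(nb, path)
```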
## Package your Kedro project -[Further information about building project documentation and packaging your project](https://kedro.readthedocs.io/en/stable/tutorial/package_a_project.html) +[Further information about building project documentation and packaging your project](https://docs.kedro.org/en/stable/tutorial/package_a_project.html) diff --git a/features/steps/test_starter/{{ cookiecutter.repo_name }}/conf/base/logging.yml b/features/steps/test_starter/{{ cookiecutter.repo_name }}/conf/base/logging.yml index d60b7d3592..984cac5069 100644 --- a/features/steps/test_starter/{{ cookiecutter.repo_name }}/conf/base/logging.yml +++ b/features/steps/test_starter/{{ cookiecutter.repo_name }}/conf/base/logging.yml @@ -17,17 +17,7 @@ handlers: class: logging.handlers.RotatingFileHandler level: INFO formatter: simple - filename: logs/info.log - maxBytes: 10485760 # 10MB - backupCount: 20 - encoding: utf8 - delay: True - - error_file_handler: - class: logging.handlers.RotatingFileHandler - level: ERROR - formatter: simple - filename: logs/errors.log + filename: info.log maxBytes: 10485760 # 10MB backupCount: 20 encoding: utf8 @@ -48,4 +38,4 @@ loggers: level: INFO root: - handlers: [rich, info_file_handler, error_file_handler] + handlers: [rich] diff --git a/features/steps/test_starter/{{ cookiecutter.repo_name }}/pyproject.toml b/features/steps/test_starter/{{ cookiecutter.repo_name }}/pyproject.toml index ca5524efc1..9208373121 100644 --- a/features/steps/test_starter/{{ cookiecutter.repo_name }}/pyproject.toml +++ b/features/steps/test_starter/{{ cookiecutter.repo_name }}/pyproject.toml @@ -3,9 +3,6 @@ project_name = "{{ cookiecutter.project_name }}" project_version = "{{ cookiecutter.kedro_version }}" package_name = "{{ cookiecutter.python_package }}" -[tool.isort] -profile = "black" - [tool.pytest.ini_options] addopts = """ --cov-report term-missing \ diff --git a/features/steps/test_starter/{{ cookiecutter.repo_name }}/src/pyproject.toml b/features/steps/test_starter/{{ cookiecutter.repo_name }}/src/pyproject.toml index ea581e7028..73b6242480 100644 --- a/features/steps/test_starter/{{ cookiecutter.repo_name }}/src/pyproject.toml +++ b/features/steps/test_starter/{{ cookiecutter.repo_name }}/src/pyproject.toml @@ -19,7 +19,6 @@ docs = [ "sphinx~=3.4.3", "sphinx_rtd_theme==0.5.1", "nbsphinx==0.8.1", - "nbstripout~=0.4", "sphinx-autodoc-typehints==1.11.1", "sphinx_copybutton==0.3.1", "ipykernel>=5.3, <7.0", diff --git a/features/steps/test_starter/{{ cookiecutter.repo_name }}/src/requirements.txt b/features/steps/test_starter/{{ cookiecutter.repo_name }}/src/requirements.txt index 7e6f29ac16..39d93ecd53 100644 --- a/features/steps/test_starter/{{ cookiecutter.repo_name }}/src/requirements.txt +++ b/features/steps/test_starter/{{ cookiecutter.repo_name }}/src/requirements.txt @@ -1,14 +1,12 @@ black~=22.0 -flake8>=3.7.9, <5.0 ipython>=7.31.1, <8.0; python_version < '3.8' ipython~=8.10; python_version >= '3.8' -isort~=5.0 jupyter~=1.0 jupyterlab_server>=2.11.1, <2.16.0 jupyterlab~=3.0, <3.6.0 -kedro[pandas.CSVDataSet]=={{ cookiecutter.kedro_version }} +kedro~={{ cookiecutter.kedro_version}} +kedro-datasets[pandas.CSVDataSet] kedro-telemetry~=0.2.0 -nbstripout~=0.4 pytest-cov~=3.0 pytest-mock>=1.7.1, <2.0 pytest~=7.2 diff --git a/features/steps/test_starter/{{ cookiecutter.repo_name }}/src/tests/test_run.py b/features/steps/test_starter/{{ cookiecutter.repo_name }}/src/tests/test_run.py index ee11dea542..f9355de19b 100644 --- a/features/steps/test_starter/{{ cookiecutter.repo_name 
}}/src/tests/test_run.py +++ b/features/steps/test_starter/{{ cookiecutter.repo_name }}/src/tests/test_run.py @@ -5,7 +5,7 @@ project's structure, and in files named test_*.py. They are simply functions named ``test_*`` which test a unit of logic. -To run the tests, run ``kedro test`` from the project root directory. +To run the tests, run ``pytest`` from the project root directory. """ from pathlib import Path diff --git a/features/steps/util.py b/features/steps/util.py index f65a4adfa3..74031232f1 100644 --- a/features/steps/util.py +++ b/features/steps/util.py @@ -63,7 +63,7 @@ def wait_for( try: result = func(**kwargs) return result - except Exception as err: # pylint: disable=broad-except + except Exception as err: if print_error: print(err) diff --git a/features/test.feature b/features/test.feature deleted file mode 100644 index 0d42f336e6..0000000000 --- a/features/test.feature +++ /dev/null @@ -1,13 +0,0 @@ -Feature: Test target in new project - - Background: - Given I have prepared a config file - And I have run a non-interactive kedro new with starter "default" - - Scenario: Execute successful test in new project - When I execute the kedro command "test" - Then I should get a successful exit code - - Scenario: Execute successful lint in new project - When I execute the kedro command "lint --check-only" - Then I should get a successful exit code diff --git a/kedro/config/common.py b/kedro/config/common.py index 35bcdcda89..0db6b637f8 100644 --- a/kedro/config/common.py +++ b/kedro/config/common.py @@ -88,10 +88,10 @@ def _get_config_from_patterns( if common_keys: sorted_keys = ", ".join(sorted(common_keys)) msg = ( - "Config from path '%s' will override the following " - "existing top-level config keys: %s" + "Config from path [magenta]%s[/magenta] will override the " + "following existing top-level config keys: '%s'" ) - _config_logger.info(msg, conf_path, sorted_keys) + _config_logger.info(msg, conf_path, sorted_keys, extra={"markup": True}) config.update(new_conf) processed_files |= set(config_filepaths) @@ -129,7 +129,11 @@ def _load_config_file( try: # Default to UTF-8, which is Python 3 default encoding, to decode the file with open(config_file, encoding="utf8") as yml: - _config_logger.debug("Loading config file: '%s'", config_file) + _config_logger.debug( + "Loading config file: [bright magenta]%s[/bright magenta]", + config_file, + extra={"markup": True}, + ) return { k: v for k, v in anyconfig.load( diff --git a/kedro/config/config.py b/kedro/config/config.py index 1af44fc4ab..5eb7b93921 100644 --- a/kedro/config/config.py +++ b/kedro/config/config.py @@ -58,9 +58,6 @@ class ConfigLoader(AbstractConfigLoader): >>> conf_path = str(project_path / settings.CONF_SOURCE) >>> conf_loader = ConfigLoader(conf_source=conf_path, env="local") >>> - >>> conf_logging = conf_loader["logging"] - >>> logging.config.dictConfig(conf_logging) # set logging conf - >>> >>> conf_catalog = conf_loader["catalog"] >>> conf_params = conf_loader["parameters"] @@ -99,7 +96,6 @@ def __init__( # noqa: too-many-arguments "catalog": ["catalog*", "catalog*/**", "**/catalog*"], "parameters": ["parameters*", "parameters*/**", "**/parameters*"], "credentials": ["credentials*", "credentials*/**", "**/credentials*"], - "logging": ["logging*", "logging*/**", "**/logging*"], } self.config_patterns.update(config_patterns or {}) diff --git a/kedro/config/omegaconf_config.py b/kedro/config/omegaconf_config.py index 89fb7d1458..5d07cb64c3 100644 --- a/kedro/config/omegaconf_config.py +++ 
b/kedro/config/omegaconf_config.py @@ -50,9 +50,6 @@ class OmegaConfigLoader(AbstractConfigLoader): >>> conf_path = str(project_path / settings.CONF_SOURCE) >>> conf_loader = OmegaConfigLoader(conf_source=conf_path, env="local") >>> - >>> conf_logging = conf_loader["logging"] - >>> logging.config.dictConfig(conf_logging) # set logging conf - >>> >>> conf_catalog = conf_loader["catalog"] >>> conf_params = conf_loader["parameters"] @@ -111,7 +108,6 @@ def __init__( # noqa: too-many-arguments "catalog": ["catalog*", "catalog*/**", "**/catalog*"], "parameters": ["parameters*", "parameters*/**", "**/parameters*"], "credentials": ["credentials*", "credentials*/**", "**/credentials*"], - "logging": ["logging*", "logging*/**", "**/logging*"], "globals": ["globals.yml"], } self.config_patterns.update(config_patterns or {}) diff --git a/kedro/config/templated_config.py b/kedro/config/templated_config.py index 1c343ec41f..615b75fdda 100644 --- a/kedro/config/templated_config.py +++ b/kedro/config/templated_config.py @@ -124,7 +124,6 @@ def __init__( # noqa: too-many-arguments "catalog": ["catalog*", "catalog*/**", "**/catalog*"], "parameters": ["parameters*", "parameters*/**", "**/parameters*"], "credentials": ["credentials*", "credentials*/**", "**/credentials*"], - "logging": ["logging*", "logging*/**", "**/logging*"], } self.config_patterns.update(config_patterns or {}) diff --git a/kedro/extras/__init__.py b/kedro/extras/__init__.py deleted file mode 100644 index 5a7dd9fb59..0000000000 --- a/kedro/extras/__init__.py +++ /dev/null @@ -1,2 +0,0 @@ -"""``kedro.extras`` provides functionality such as datasets and extensions. -""" diff --git a/kedro/extras/datasets/README.md b/kedro/extras/datasets/README.md deleted file mode 100644 index bd93acd6be..0000000000 --- a/kedro/extras/datasets/README.md +++ /dev/null @@ -1,22 +0,0 @@ -# Datasets - -> **Warning** -> `kedro.extras.datasets` is deprecated and will be removed in Kedro 0.19, -> install `kedro-datasets` instead by running `pip install kedro-datasets`. - -Welcome to `kedro.extras.datasets`, the home of Kedro's data connectors. Here you will find `AbstractDataset` implementations created by QuantumBlack and external contributors. - -## What `AbstractDataset` implementations are supported? - -We support a range of data descriptions, including CSV, Excel, Parquet, Feather, HDF5, JSON, Pickle, SQL Tables, SQL Queries, Spark DataFrames and more. We even allow support for working with images. - -These data descriptions are supported with the APIs of `pandas`, `spark`, `networkx`, `matplotlib`, `yaml` and more. - -[The Data Catalog](https://kedro.readthedocs.io/en/stable/data/data_catalog.html) allows you to work with a range of file formats on local file systems, network file systems, cloud object stores, and Hadoop. - -Here is a full list of [supported data descriptions and APIs](https://kedro.readthedocs.io/en/stable/kedro.extras.datasets.html). - -## How can I create my own `AbstractDataset` implementation? - - -Take a look at our [instructions on how to create your own `AbstractDataset` implementation](https://kedro.readthedocs.io/en/stable/extend_kedro/custom_datasets.html). diff --git a/kedro/extras/datasets/__init__.py b/kedro/extras/datasets/__init__.py deleted file mode 100644 index 3eec3e3fe1..0000000000 --- a/kedro/extras/datasets/__init__.py +++ /dev/null @@ -1,19 +0,0 @@ -"""``kedro.extras.datasets`` is where you can find all of Kedro's data connectors. -These data connectors are implementations of the ``AbstractDataset``. - -.. 
warning:: - - ``kedro.extras.datasets`` is deprecated and will be removed in Kedro 0.19. - Refer to :py:mod:`kedro_datasets` for the documentation, and - install ``kedro-datasets`` to avoid breakage by running ``pip install kedro-datasets``. - -""" - -from warnings import warn as _warn - -_warn( - "`kedro.extras.datasets` is deprecated and will be removed in Kedro 0.19, " - "install `kedro-datasets` instead by running `pip install kedro-datasets`.", - DeprecationWarning, - stacklevel=2, -) diff --git a/kedro/extras/datasets/api/__init__.py b/kedro/extras/datasets/api/__init__.py deleted file mode 100644 index ccd799b2c9..0000000000 --- a/kedro/extras/datasets/api/__init__.py +++ /dev/null @@ -1,11 +0,0 @@ -"""``APIDataSet`` loads the data from HTTP(S) APIs -and returns them into either as string or json Dict. -It uses the python requests library: https://requests.readthedocs.io/en/latest/ -""" - -__all__ = ["APIDataSet"] - -from contextlib import suppress - -with suppress(ImportError): - from .api_dataset import APIDataSet diff --git a/kedro/extras/datasets/api/api_dataset.py b/kedro/extras/datasets/api/api_dataset.py deleted file mode 100644 index 0e79f9aad2..0000000000 --- a/kedro/extras/datasets/api/api_dataset.py +++ /dev/null @@ -1,141 +0,0 @@ -"""``APIDataSet`` loads the data from HTTP(S) APIs. -It uses the python requests library: https://requests.readthedocs.io/en/latest/ -""" -from typing import Any, Dict, Iterable, List, NoReturn, Union - -import requests -from requests.auth import AuthBase - -from kedro.io.core import AbstractDataset, DatasetError - -# NOTE: kedro.extras.datasets will be removed in Kedro 0.19.0. -# Any contribution to datasets should be made in kedro-datasets -# in kedro-plugins (https://github.com/kedro-org/kedro-plugins) - - -class APIDataSet(AbstractDataset[None, requests.Response]): - """``APIDataSet`` loads the data from HTTP(S) APIs. - It uses the python requests library: https://requests.readthedocs.io/en/latest/ - - Example usage for the - `YAML API `_: - - - .. code-block:: yaml - - usda: - type: api.APIDataSet - url: https://quickstats.nass.usda.gov - params: - key: SOME_TOKEN, - format: JSON, - commodity_desc: CORN, - statisticcat_des: YIELD, - agg_level_desc: STATE, - year: 2000 - - Example usage for the - `Python API `_: - :: - - >>> from kedro.extras.datasets.api import APIDataSet - >>> - >>> - >>> data_set = APIDataSet( - >>> url="https://quickstats.nass.usda.gov", - >>> params={ - >>> "key": "SOME_TOKEN", - >>> "format": "JSON", - >>> "commodity_desc": "CORN", - >>> "statisticcat_des": "YIELD", - >>> "agg_level_desc": "STATE", - >>> "year": 2000 - >>> } - >>> ) - >>> data = data_set.load() - """ - - def __init__( # noqa: too-many-arguments - self, - url: str, - method: str = "GET", - data: Any = None, - params: Dict[str, Any] = None, - headers: Dict[str, Any] = None, - auth: Union[Iterable[str], AuthBase] = None, - json: Union[List, Dict[str, Any]] = None, - timeout: int = 60, - credentials: Union[Iterable[str], AuthBase] = None, - ) -> None: - """Creates a new instance of ``APIDataSet`` to fetch data from an API endpoint. - - Args: - url: The API URL endpoint. - method: The Method of the request, GET, POST, PUT, DELETE, HEAD, etc... - data: The request payload, used for POST, PUT, etc requests - https://requests.readthedocs.io/en/latest/user/quickstart/#more-complicated-post-requests - params: The url parameters of the API. - https://requests.readthedocs.io/en/latest/user/quickstart/#passing-parameters-in-urls - headers: The HTTP headers. 
- https://requests.readthedocs.io/en/latest/user/quickstart/#custom-headers - auth: Anything ``requests`` accepts. Normally it's either ``('login', 'password')``, - or ``AuthBase``, ``HTTPBasicAuth`` instance for more complex cases. Any - iterable will be cast to a tuple. - json: The request payload, used for POST, PUT, etc requests, passed in - to the json kwarg in the requests object. - https://requests.readthedocs.io/en/latest/user/quickstart/#more-complicated-post-requests - timeout: The wait time in seconds for a response, defaults to 1 minute. - https://requests.readthedocs.io/en/latest/user/quickstart/#timeouts - credentials: same as ``auth``. Allows specifying ``auth`` secrets in - credentials.yml. - - Raises: - ValueError: if both ``credentials`` and ``auth`` are specified. - """ - super().__init__() - - if credentials is not None and auth is not None: - raise ValueError("Cannot specify both auth and credentials.") - - auth = credentials or auth - - if isinstance(auth, Iterable): - auth = tuple(auth) - - self._request_args: Dict[str, Any] = { - "url": url, - "method": method, - "data": data, - "params": params, - "headers": headers, - "auth": auth, - "json": json, - "timeout": timeout, - } - - def _describe(self) -> Dict[str, Any]: - return {**self._request_args} - - def _execute_request(self) -> requests.Response: - try: - response = requests.request(**self._request_args) - response.raise_for_status() - except requests.exceptions.HTTPError as exc: - raise DatasetError("Failed to fetch data", exc) from exc - except OSError as exc: - raise DatasetError("Failed to connect to the remote server") from exc - - return response - - def _load(self) -> requests.Response: - return self._execute_request() - - def _save(self, data: None) -> NoReturn: - raise DatasetError(f"{self.__class__.__name__} is a read only data set type") - - def _exists(self) -> bool: - response = self._execute_request() - - return response.ok diff --git a/kedro/extras/datasets/biosequence/__init__.py b/kedro/extras/datasets/biosequence/__init__.py deleted file mode 100644 index d806e3ca33..0000000000 --- a/kedro/extras/datasets/biosequence/__init__.py +++ /dev/null @@ -1,8 +0,0 @@ -"""``AbstractDataset`` implementation to read/write from/to a sequence file.""" - -__all__ = ["BioSequenceDataSet"] - -from contextlib import suppress - -with suppress(ImportError): - from .biosequence_dataset import BioSequenceDataSet diff --git a/kedro/extras/datasets/biosequence/biosequence_dataset.py b/kedro/extras/datasets/biosequence/biosequence_dataset.py deleted file mode 100644 index ac0770aa68..0000000000 --- a/kedro/extras/datasets/biosequence/biosequence_dataset.py +++ /dev/null @@ -1,136 +0,0 @@ -"""BioSequenceDataSet loads and saves data to/from bio-sequence objects to -file. -""" -from copy import deepcopy -from pathlib import PurePosixPath -from typing import Any, Dict, List - -import fsspec -from Bio import SeqIO - -from kedro.io.core import AbstractDataset, get_filepath_str, get_protocol_and_path - -# NOTE: kedro.extras.datasets will be removed in Kedro 0.19.0. -# Any contribution to datasets should be made in kedro-datasets -# in kedro-plugins (https://github.com/kedro-org/kedro-plugins) - - -class BioSequenceDataSet(AbstractDataset[List, List]): - r"""``BioSequenceDataSet`` loads and saves data to a sequence file. 
- - Example: - :: - - >>> from kedro.extras.datasets.biosequence import BioSequenceDataSet - >>> from io import StringIO - >>> from Bio import SeqIO - >>> - >>> data = ">Alpha\nACCGGATGTA\n>Beta\nAGGCTCGGTTA\n" - >>> raw_data = [] - >>> for record in SeqIO.parse(StringIO(data), "fasta"): - >>> raw_data.append(record) - >>> - >>> data_set = BioSequenceDataSet(filepath="ls_orchid.fasta", - >>> load_args={"format": "fasta"}, - >>> save_args={"format": "fasta"}) - >>> data_set.save(raw_data) - >>> sequence_list = data_set.load() - >>> - >>> assert raw_data[0].id == sequence_list[0].id - >>> assert raw_data[0].seq == sequence_list[0].seq - - """ - - DEFAULT_LOAD_ARGS = {} # type: Dict[str, Any] - DEFAULT_SAVE_ARGS = {} # type: Dict[str, Any] - - def __init__( # noqa: too-many-arguments - self, - filepath: str, - load_args: Dict[str, Any] = None, - save_args: Dict[str, Any] = None, - credentials: Dict[str, Any] = None, - fs_args: Dict[str, Any] = None, - ) -> None: - """ - Creates a new instance of ``BioSequenceDataSet`` pointing - to a concrete filepath. - - Args: - filepath: Filepath in POSIX format to sequence file prefixed with a protocol like - `s3://`. If prefix is not provided, `file` protocol (local filesystem) will be used. - The prefix should be any protocol supported by ``fsspec``. - load_args: Options for parsing sequence files by Biopython ``SeqIO.parse()``. - save_args: file format supported by Biopython ``SeqIO.write()``. - E.g. `{"format": "fasta"}`. - credentials: Credentials required to get access to the underlying filesystem. - E.g. for ``GCSFileSystem`` it should look like `{"token": None}`. - fs_args: Extra arguments to pass into underlying filesystem class constructor - (e.g. `{"project": "my-project"}` for ``GCSFileSystem``), as well as - to pass to the filesystem's `open` method through nested keys - `open_args_load` and `open_args_save`. - Here you can find all available arguments for `open`: - https://filesystem-spec.readthedocs.io/en/latest/api.html#fsspec.spec.AbstractFileSystem.open - All defaults are preserved, except `mode`, which is set to `r` when loading - and to `w` when saving. 
- - Note: Here you can find all supported file formats: https://biopython.org/wiki/SeqIO - """ - - _fs_args = deepcopy(fs_args) or {} - _fs_open_args_load = _fs_args.pop("open_args_load", {}) - _fs_open_args_save = _fs_args.pop("open_args_save", {}) - _credentials = deepcopy(credentials) or {} - - protocol, path = get_protocol_and_path(filepath) - - self._filepath = PurePosixPath(path) - self._protocol = protocol - if protocol == "file": - _fs_args.setdefault("auto_mkdir", True) - - self._fs = fsspec.filesystem(self._protocol, **_credentials, **_fs_args) - - # Handle default load and save arguments - self._load_args = deepcopy(self.DEFAULT_LOAD_ARGS) - if load_args is not None: - self._load_args.update(load_args) - self._save_args = deepcopy(self.DEFAULT_SAVE_ARGS) - if save_args is not None: - self._save_args.update(save_args) - - _fs_open_args_load.setdefault("mode", "r") - _fs_open_args_save.setdefault("mode", "w") - self._fs_open_args_load = _fs_open_args_load - self._fs_open_args_save = _fs_open_args_save - - def _describe(self) -> Dict[str, Any]: - return { - "filepath": self._filepath, - "protocol": self._protocol, - "load_args": self._load_args, - "save_args": self._save_args, - } - - def _load(self) -> List: - load_path = get_filepath_str(self._filepath, self._protocol) - with self._fs.open(load_path, **self._fs_open_args_load) as fs_file: - return list(SeqIO.parse(handle=fs_file, **self._load_args)) - - def _save(self, data: List) -> None: - save_path = get_filepath_str(self._filepath, self._protocol) - - with self._fs.open(save_path, **self._fs_open_args_save) as fs_file: - SeqIO.write(data, handle=fs_file, **self._save_args) - - def _exists(self) -> bool: - load_path = get_filepath_str(self._filepath, self._protocol) - return self._fs.exists(load_path) - - def _release(self) -> None: - self.invalidate_cache() - - def invalidate_cache(self) -> None: - """Invalidate underlying filesystem caches.""" - filepath = get_filepath_str(self._filepath, self._protocol) - self._fs.invalidate_cache(filepath) diff --git a/kedro/extras/datasets/dask/__init__.py b/kedro/extras/datasets/dask/__init__.py deleted file mode 100644 index d93bf4c63f..0000000000 --- a/kedro/extras/datasets/dask/__init__.py +++ /dev/null @@ -1,8 +0,0 @@ -"""Provides I/O modules using dask dataframe.""" - -__all__ = ["ParquetDataSet"] - -from contextlib import suppress - -with suppress(ImportError): - from .parquet_dataset import ParquetDataSet diff --git a/kedro/extras/datasets/dask/parquet_dataset.py b/kedro/extras/datasets/dask/parquet_dataset.py deleted file mode 100644 index 21fcfe25b0..0000000000 --- a/kedro/extras/datasets/dask/parquet_dataset.py +++ /dev/null @@ -1,210 +0,0 @@ -"""``ParquetDataSet`` is a data set used to load and save data to parquet files using Dask -dataframe""" - -from copy import deepcopy -from typing import Any, Dict - -import dask.dataframe as dd -import fsspec -import triad - -from kedro.io.core import AbstractDataset, get_protocol_and_path - -# NOTE: kedro.extras.datasets will be removed in Kedro 0.19.0. -# Any contribution to datasets should be made in kedro-datasets -# in kedro-plugins (https://github.com/kedro-org/kedro-plugins) - - -class ParquetDataSet(AbstractDataset[dd.DataFrame, dd.DataFrame]): - """``ParquetDataSet`` loads and saves data to parquet file(s). It uses Dask - remote data services to handle the corresponding load and save operations: - https://docs.dask.org/en/latest/how-to/connect-to-remote-data.html - - Example usage for the - `YAML API `_: - - .. 
code-block:: yaml - - cars: - type: dask.ParquetDataSet - filepath: s3://bucket_name/path/to/folder - save_args: - compression: GZIP - credentials: - client_kwargs: - aws_access_key_id: YOUR_KEY - aws_secret_access_key: YOUR_SECRET - - Example usage for the - `Python API `_: - :: - - - >>> from kedro.extras.datasets.dask import ParquetDataSet - >>> import pandas as pd - >>> import dask.dataframe as dd - >>> - >>> data = pd.DataFrame({'col1': [1, 2], 'col2': [4, 5], - >>> 'col3': [[5, 6], [7, 8]]}) - >>> ddf = dd.from_pandas(data, npartitions=2) - >>> - >>> data_set = ParquetDataSet( - >>> filepath="s3://bucket_name/path/to/folder", - >>> credentials={ - >>> 'client_kwargs':{ - >>> 'aws_access_key_id': 'YOUR_KEY', - >>> 'aws_secret_access_key': 'YOUR SECRET', - >>> } - >>> }, - >>> save_args={"compression": "GZIP"} - >>> ) - >>> data_set.save(ddf) - >>> reloaded = data_set.load() - >>> - >>> assert ddf.compute().equals(reloaded.compute()) - - The output schema can also be explicitly specified using - `Triad `_. - This is processed to map specific columns to - `PyArrow field types `_ or schema. For instance: - - .. code-block:: yaml - - parquet_dataset: - type: dask.ParquetDataSet - filepath: "s3://bucket_name/path/to/folder" - credentials: - client_kwargs: - aws_access_key_id: YOUR_KEY - aws_secret_access_key: "YOUR SECRET" - save_args: - compression: GZIP - schema: - col1: [int32] - col2: [int32] - col3: [[int32]] - """ - - DEFAULT_LOAD_ARGS = {} # type: Dict[str, Any] - DEFAULT_SAVE_ARGS = {"write_index": False} # type: Dict[str, Any] - - def __init__( # noqa: too-many-arguments - self, - filepath: str, - load_args: Dict[str, Any] = None, - save_args: Dict[str, Any] = None, - credentials: Dict[str, Any] = None, - fs_args: Dict[str, Any] = None, - ) -> None: - """Creates a new instance of ``ParquetDataSet`` pointing to concrete - parquet files. - - Args: - filepath: Filepath in POSIX format to a parquet file - parquet collection or the directory of a multipart parquet. - load_args: Additional loading options `dask.dataframe.read_parquet`: - https://docs.dask.org/en/latest/generated/dask.dataframe.read_parquet.html - save_args: Additional saving options for `dask.dataframe.to_parquet`: - https://docs.dask.org/en/latest/generated/dask.dataframe.to_parquet.html - credentials: Credentials required to get access to the underlying filesystem. - E.g. for ``GCSFileSystem`` it should look like `{"token": None}`. - fs_args: Optional parameters to the backend file system driver: - https://docs.dask.org/en/latest/how-to/connect-to-remote-data.html#optional-parameters - """ - self._filepath = filepath - self._fs_args = deepcopy(fs_args) or {} - self._credentials = deepcopy(credentials) or {} - - # Handle default load and save arguments - self._load_args = deepcopy(self.DEFAULT_LOAD_ARGS) - if load_args is not None: - self._load_args.update(load_args) - self._save_args = deepcopy(self.DEFAULT_SAVE_ARGS) - if save_args is not None: - self._save_args.update(save_args) - - @property - def fs_args(self) -> Dict[str, Any]: - """Property of optional file system parameters. - - Returns: - A dictionary of backend file system parameters, including credentials. 
- """ - fs_args = deepcopy(self._fs_args) - fs_args.update(self._credentials) - return fs_args - - def _describe(self) -> Dict[str, Any]: - return { - "filepath": self._filepath, - "load_args": self._load_args, - "save_args": self._save_args, - } - - def _load(self) -> dd.DataFrame: - return dd.read_parquet( - self._filepath, storage_options=self.fs_args, **self._load_args - ) - - def _save(self, data: dd.DataFrame) -> None: - self._process_schema() - data.to_parquet(self._filepath, storage_options=self.fs_args, **self._save_args) - - def _process_schema(self) -> None: - """This method processes the schema in the catalog.yml or the API, if provided. - This assumes that the schema is specified using Triad's grammar for - schema definition. - - When the value of the `schema` variable is a string, it is assumed that - it corresponds to the full schema specification for the data. - - Alternatively, if the `schema` is specified as a dictionary, then only the - columns that are specified will be strictly mapped to a field type. The other - unspecified columns, if present, will be inferred from the data. - - This method converts the Triad-parsed schema into a pyarrow schema. - The output directly supports Dask's specifications for providing a schema - when saving to a parquet file. - - Note that if a `pa.Schema` object is passed directly in the `schema` argument, no - processing will be done. Additionally, the behavior when passing a `pa.Schema` - object is assumed to be consistent with how Dask sees it. That is, it should fully - define the schema for all fields. - """ - schema = self._save_args.get("schema") - - if isinstance(schema, dict): - # The schema may contain values of different types, e.g., pa.DataType, Python types, - # strings, etc. The latter requires a transformation, then we use triad handle all - # other value types. - - # Create a schema from values that triad can handle directly - triad_schema = triad.Schema( - {k: v for k, v in schema.items() if not isinstance(v, str)} - ) - - # Handle the schema keys that are represented as string and add them to the triad schema - triad_schema.update( - triad.Schema( - ",".join( - [f"{k}:{v}" for k, v in schema.items() if isinstance(v, str)] - ) - ) - ) - - # Update the schema argument with the normalized schema - self._save_args["schema"].update( - {col: field.type for col, field in triad_schema.items()} - ) - - elif isinstance(schema, str): - self._save_args["schema"] = triad.Schema(schema).pyarrow_schema - - def _exists(self) -> bool: - protocol = get_protocol_and_path(self._filepath)[0] - file_system = fsspec.filesystem(protocol=protocol, **self.fs_args) - return file_system.exists(self._filepath) diff --git a/kedro/extras/datasets/email/__init__.py b/kedro/extras/datasets/email/__init__.py deleted file mode 100644 index ba7873cbf2..0000000000 --- a/kedro/extras/datasets/email/__init__.py +++ /dev/null @@ -1,8 +0,0 @@ -"""``AbstractDataset`` implementations for managing email messages.""" - -__all__ = ["EmailMessageDataSet"] - -from contextlib import suppress - -with suppress(ImportError): - from .message_dataset import EmailMessageDataSet diff --git a/kedro/extras/datasets/email/message_dataset.py b/kedro/extras/datasets/email/message_dataset.py deleted file mode 100644 index 695d93cbbe..0000000000 --- a/kedro/extras/datasets/email/message_dataset.py +++ /dev/null @@ -1,187 +0,0 @@ -"""``EmailMessageDataSet`` loads/saves an email message from/to a file -using an underlying filesystem (e.g.: local, S3, GCS). 
It uses the -``email`` package in the standard library to manage email messages. -""" -from copy import deepcopy -from email.generator import Generator -from email.message import Message -from email.parser import Parser -from email.policy import default -from pathlib import PurePosixPath -from typing import Any, Dict - -import fsspec - -from kedro.io.core import ( - AbstractVersionedDataset, - DatasetError, - Version, - get_filepath_str, - get_protocol_and_path, -) - -# NOTE: kedro.extras.datasets will be removed in Kedro 0.19.0. -# Any contribution to datasets should be made in kedro-datasets -# in kedro-plugins (https://github.com/kedro-org/kedro-plugins) - - -class EmailMessageDataSet( - AbstractVersionedDataset[Message, Message] -): # pylint: disable=too-many-instance-attributes - """``EmailMessageDataSet`` loads/saves an email message from/to a file - using an underlying filesystem (e.g.: local, S3, GCS). It uses the - ``email`` package in the standard library to manage email messages. - - Note that ``EmailMessageDataSet`` doesn't handle sending email messages. - - Example: - :: - - >>> from email.message import EmailMessage - >>> - >>> from kedro.extras.datasets.email import EmailMessageDataSet - >>> - >>> string_to_write = "what would you do if you were invisable for one day????" - >>> - >>> # Create a text/plain message - >>> msg = EmailMessage() - >>> msg.set_content(string_to_write) - >>> msg["Subject"] = "invisibility" - >>> msg["From"] = '"sin studly17"' - >>> msg["To"] = '"strong bad"' - >>> - >>> data_set = EmailMessageDataSet(filepath="test") - >>> data_set.save(msg) - >>> reloaded = data_set.load() - >>> assert msg.__dict__ == reloaded.__dict__ - - """ - - DEFAULT_LOAD_ARGS = {} # type: Dict[str, Any] - DEFAULT_SAVE_ARGS = {} # type: Dict[str, Any] - - def __init__( # noqa: too-many-arguments - self, - filepath: str, - load_args: Dict[str, Any] = None, - save_args: Dict[str, Any] = None, - version: Version = None, - credentials: Dict[str, Any] = None, - fs_args: Dict[str, Any] = None, - ) -> None: - """Creates a new instance of ``EmailMessageDataSet`` pointing to a concrete text file - on a specific filesystem. - - Args: - filepath: Filepath in POSIX format to a text file prefixed with a protocol like `s3://`. - If prefix is not provided, `file` protocol (local filesystem) will be used. - The prefix should be any protocol supported by ``fsspec``. - Note: `http(s)` doesn't support versioning. - load_args: ``email`` options for parsing email messages (arguments passed - into ``email.parser.Parser.parse``). Here you can find all available arguments: - https://docs.python.org/3/library/email.parser.html#email.parser.Parser.parse - If you would like to specify options for the `Parser`, - you can include them under the "parser" key. Here you can - find all available arguments: - https://docs.python.org/3/library/email.parser.html#email.parser.Parser - All defaults are preserved, but "policy", which is set to ``email.policy.default``. - save_args: ``email`` options for generating MIME documents (arguments passed into - ``email.generator.Generator.flatten``). Here you can find all available arguments: - https://docs.python.org/3/library/email.generator.html#email.generator.Generator.flatten - If you would like to specify options for the `Generator`, - you can include them under the "generator" key. Here you can - find all available arguments: - https://docs.python.org/3/library/email.generator.html#email.generator.Generator - All defaults are preserved. 
- version: If specified, should be an instance of - ``kedro.io.core.Version``. If its ``load`` attribute is - None, the latest version will be loaded. If its ``save`` - attribute is None, save version will be autogenerated. - credentials: Credentials required to get access to the underlying filesystem. - E.g. for ``GCSFileSystem`` it should look like `{"token": None}`. - fs_args: Extra arguments to pass into underlying filesystem class constructor - (e.g. `{"project": "my-project"}` for ``GCSFileSystem``), as well as - to pass to the filesystem's `open` method through nested keys - `open_args_load` and `open_args_save`. - Here you can find all available arguments for `open`: - https://filesystem-spec.readthedocs.io/en/latest/api.html#fsspec.spec.AbstractFileSystem.open - All defaults are preserved, except `mode`, which is set to `r` when loading - and to `w` when saving. - """ - _fs_args = deepcopy(fs_args) or {} - _fs_open_args_load = _fs_args.pop("open_args_load", {}) - _fs_open_args_save = _fs_args.pop("open_args_save", {}) - _credentials = deepcopy(credentials) or {} - - protocol, path = get_protocol_and_path(filepath, version) - - self._protocol = protocol - if protocol == "file": - _fs_args.setdefault("auto_mkdir", True) - self._fs = fsspec.filesystem(self._protocol, **_credentials, **_fs_args) - - super().__init__( - filepath=PurePosixPath(path), - version=version, - exists_function=self._fs.exists, - glob_function=self._fs.glob, - ) - - # Handle default load arguments - self._load_args = deepcopy(self.DEFAULT_LOAD_ARGS) - if load_args is not None: - self._load_args.update(load_args) - self._parser_args = self._load_args.pop("parser", {"policy": default}) - - # Handle default save arguments - self._save_args = deepcopy(self.DEFAULT_SAVE_ARGS) - if save_args is not None: - self._save_args.update(save_args) - self._generator_args = self._save_args.pop("generator", {}) - - _fs_open_args_load.setdefault("mode", "r") - _fs_open_args_save.setdefault("mode", "w") - self._fs_open_args_load = _fs_open_args_load - self._fs_open_args_save = _fs_open_args_save - - def _describe(self) -> Dict[str, Any]: - return { - "filepath": self._filepath, - "protocol": self._protocol, - "load_args": self._load_args, - "parser_args": self._parser_args, - "save_args": self._save_args, - "generator_args": self._generator_args, - "version": self._version, - } - - def _load(self) -> Message: - load_path = get_filepath_str(self._get_load_path(), self._protocol) - - with self._fs.open(load_path, **self._fs_open_args_load) as fs_file: - return Parser(**self._parser_args).parse(fs_file, **self._load_args) - - def _save(self, data: Message) -> None: - save_path = get_filepath_str(self._get_save_path(), self._protocol) - - with self._fs.open(save_path, **self._fs_open_args_save) as fs_file: - Generator(fs_file, **self._generator_args).flatten(data, **self._save_args) - - self._invalidate_cache() - - def _exists(self) -> bool: - try: - load_path = get_filepath_str(self._get_load_path(), self._protocol) - except DatasetError: - return False - - return self._fs.exists(load_path) - - def _release(self) -> None: - super()._release() - self._invalidate_cache() - - def _invalidate_cache(self) -> None: - """Invalidate underlying filesystem caches.""" - filepath = get_filepath_str(self._filepath, self._protocol) - self._fs.invalidate_cache(filepath) diff --git a/kedro/extras/datasets/geopandas/README.md b/kedro/extras/datasets/geopandas/README.md deleted file mode 100644 index d7c1a3c96a..0000000000 --- 
a/kedro/extras/datasets/geopandas/README.md +++ /dev/null @@ -1,31 +0,0 @@ -# GeoJSON - -``GeoJSONDataSet`` loads and saves data to a local yaml file using ``geopandas``. -See [geopandas.GeoDataFrame](http://geopandas.org/reference/geopandas.GeoDataFrame.html) for details. - -#### Example use: - -```python -import geopandas as gpd -from shapely.geometry import Point -from kedro.extras.datasets.geopandas import GeoJSONDataSet - -data = gpd.GeoDataFrame( - {"col1": [1, 2], "col2": [4, 5], "col3": [5, 6]}, - geometry=[Point(1, 1), Point(2, 4)], -) -data_set = GeoJSONDataSet(filepath="test.geojson") -data_set.save(data) -reloaded = data_set.load() -assert data.equals(reloaded) -``` - -#### Example catalog.yml: - -```yaml -example_geojson_data: - type: geopandas.GeoJSONDataSet - filepath: data/08_reporting/test.geojson -``` - -Contributed by (Luis Blanche)[https://github.com/lblanche]. diff --git a/kedro/extras/datasets/geopandas/__init__.py b/kedro/extras/datasets/geopandas/__init__.py deleted file mode 100644 index bee7462a83..0000000000 --- a/kedro/extras/datasets/geopandas/__init__.py +++ /dev/null @@ -1,8 +0,0 @@ -"""``GeoJSONDataSet`` is an ``AbstractVersionedDataset`` to save and load GeoJSON files. -""" -__all__ = ["GeoJSONDataSet"] - -from contextlib import suppress - -with suppress(ImportError): - from .geojson_dataset import GeoJSONDataSet diff --git a/kedro/extras/datasets/geopandas/geojson_dataset.py b/kedro/extras/datasets/geopandas/geojson_dataset.py deleted file mode 100644 index 5beba29d57..0000000000 --- a/kedro/extras/datasets/geopandas/geojson_dataset.py +++ /dev/null @@ -1,155 +0,0 @@ -"""GeoJSONDataSet loads and saves data to a local geojson file. The -underlying functionality is supported by geopandas, so it supports all -allowed geopandas (pandas) options for loading and saving geosjon files. -""" -import copy -from pathlib import PurePosixPath -from typing import Any, Dict, Union - -import fsspec -import geopandas as gpd - -from kedro.io.core import ( - AbstractVersionedDataset, - DatasetError, - Version, - get_filepath_str, - get_protocol_and_path, -) - -# NOTE: kedro.extras.datasets will be removed in Kedro 0.19.0. -# Any contribution to datasets should be made in kedro-datasets -# in kedro-plugins (https://github.com/kedro-org/kedro-plugins) - - -class GeoJSONDataSet( - AbstractVersionedDataset[ - gpd.GeoDataFrame, Union[gpd.GeoDataFrame, Dict[str, gpd.GeoDataFrame]] - ] -): - """``GeoJSONDataSet`` loads/saves data to a GeoJSON file using an underlying filesystem - (eg: local, S3, GCS). - The underlying functionality is supported by geopandas, so it supports all - allowed geopandas (pandas) options for loading and saving GeoJSON files. 
- - Example: - :: - - >>> import geopandas as gpd - >>> from shapely.geometry import Point - >>> from kedro.extras.datasets.geopandas import GeoJSONDataSet - >>> - >>> data = gpd.GeoDataFrame({'col1': [1, 2], 'col2': [4, 5], - >>> 'col3': [5, 6]}, geometry=[Point(1,1), Point(2,4)]) - >>> data_set = GeoJSONDataSet(filepath="test.geojson", save_args=None) - >>> data_set.save(data) - >>> reloaded = data_set.load() - >>> - >>> assert data.equals(reloaded) - - """ - - DEFAULT_LOAD_ARGS = {} # type: Dict[str, Any] - DEFAULT_SAVE_ARGS = {"driver": "GeoJSON"} - - def __init__( # noqa: too-many-arguments - self, - filepath: str, - load_args: Dict[str, Any] = None, - save_args: Dict[str, Any] = None, - version: Version = None, - credentials: Dict[str, Any] = None, - fs_args: Dict[str, Any] = None, - ) -> None: - """Creates a new instance of ``GeoJSONDataSet`` pointing to a concrete GeoJSON file - on a specific filesystem fsspec. - - Args: - - filepath: Filepath in POSIX format to a GeoJSON file prefixed with a protocol like - `s3://`. If prefix is not provided `file` protocol (local filesystem) will be used. - The prefix should be any protocol supported by ``fsspec``. - Note: `http(s)` doesn't support versioning. - load_args: GeoPandas options for loading GeoJSON files. - Here you can find all available arguments: - https://geopandas.org/en/stable/docs/reference/api/geopandas.read_file.html - save_args: GeoPandas options for saving geojson files. - Here you can find all available arguments: - https://geopandas.org/en/stable/docs/reference/api/geopandas.GeoDataFrame.to_file.html - The default_save_arg driver is 'GeoJSON', all others preserved. - version: If specified, should be an instance of - ``kedro.io.core.Version``. If its ``load`` attribute is - None, the latest version will be loaded. If its ``save`` - credentials: credentials required to access the underlying filesystem. - Eg. for ``GCFileSystem`` it would look like `{'token': None}`. - fs_args: Extra arguments to pass into underlying filesystem class constructor - (e.g. `{"project": "my-project"}` for ``GCSFileSystem``), as well as - to pass to the filesystem's `open` method through nested keys - `open_args_load` and `open_args_save`. - Here you can find all available arguments for `open`: - https://filesystem-spec.readthedocs.io/en/latest/api.html#fsspec.spec.AbstractFileSystem.open - All defaults are preserved, except `mode`, which is set to `wb` when saving. 
- """ - _fs_args = copy.deepcopy(fs_args) or {} - _fs_open_args_load = _fs_args.pop("open_args_load", {}) - _fs_open_args_save = _fs_args.pop("open_args_save", {}) - _credentials = copy.deepcopy(credentials) or {} - protocol, path = get_protocol_and_path(filepath, version) - self._protocol = protocol - if protocol == "file": - _fs_args.setdefault("auto_mkdir", True) - - self._fs = fsspec.filesystem(self._protocol, **_credentials, **_fs_args) - - super().__init__( - filepath=PurePosixPath(path), - version=version, - exists_function=self._fs.exists, - glob_function=self._fs.glob, - ) - - self._load_args = copy.deepcopy(self.DEFAULT_LOAD_ARGS) - if load_args is not None: - self._load_args.update(load_args) - - self._save_args = copy.deepcopy(self.DEFAULT_SAVE_ARGS) - if save_args is not None: - self._save_args.update(save_args) - - _fs_open_args_save.setdefault("mode", "wb") - self._fs_open_args_load = _fs_open_args_load - self._fs_open_args_save = _fs_open_args_save - - def _load(self) -> Union[gpd.GeoDataFrame, Dict[str, gpd.GeoDataFrame]]: - load_path = get_filepath_str(self._get_load_path(), self._protocol) - with self._fs.open(load_path, **self._fs_open_args_load) as fs_file: - return gpd.read_file(fs_file, **self._load_args) - - def _save(self, data: gpd.GeoDataFrame) -> None: - save_path = get_filepath_str(self._get_save_path(), self._protocol) - with self._fs.open(save_path, **self._fs_open_args_save) as fs_file: - data.to_file(fs_file, **self._save_args) - self.invalidate_cache() - - def _exists(self) -> bool: - try: - load_path = get_filepath_str(self._get_load_path(), self._protocol) - except DatasetError: - return False - return self._fs.exists(load_path) - - def _describe(self) -> Dict[str, Any]: - return { - "filepath": self._filepath, - "protocol": self._load_args, - "save_args": self._save_args, - "version": self._version, - } - - def _release(self) -> None: - self.invalidate_cache() - - def invalidate_cache(self) -> None: - """Invalidate underlying filesystem cache.""" - filepath = get_filepath_str(self._filepath, self._protocol) - self._fs.invalidate_cache(filepath) diff --git a/kedro/extras/datasets/holoviews/__init__.py b/kedro/extras/datasets/holoviews/__init__.py deleted file mode 100644 index f50db9b823..0000000000 --- a/kedro/extras/datasets/holoviews/__init__.py +++ /dev/null @@ -1,8 +0,0 @@ -"""``AbstractDataset`` implementation to save Holoviews objects as image files.""" - -__all__ = ["HoloviewsWriter"] - -from contextlib import suppress - -with suppress(ImportError): - from .holoviews_writer import HoloviewsWriter diff --git a/kedro/extras/datasets/holoviews/holoviews_writer.py b/kedro/extras/datasets/holoviews/holoviews_writer.py deleted file mode 100644 index 34daeb1769..0000000000 --- a/kedro/extras/datasets/holoviews/holoviews_writer.py +++ /dev/null @@ -1,136 +0,0 @@ -"""``HoloviewsWriter`` saves Holoviews objects as image file(s) to an underlying -filesystem (e.g. local, S3, GCS).""" - -import io -from copy import deepcopy -from pathlib import PurePosixPath -from typing import Any, Dict, NoReturn, TypeVar - -import fsspec -import holoviews as hv - -from kedro.io.core import ( - AbstractVersionedDataset, - DatasetError, - Version, - get_filepath_str, - get_protocol_and_path, -) - -# NOTE: kedro.extras.datasets will be removed in Kedro 0.19.0. 
-# Any contribution to datasets should be made in kedro-datasets -# in kedro-plugins (https://github.com/kedro-org/kedro-plugins) - -# HoloViews to be passed in `hv.save()` -HoloViews = TypeVar("HoloViews") - - -class HoloviewsWriter(AbstractVersionedDataset[HoloViews, NoReturn]): - """``HoloviewsWriter`` saves Holoviews objects to image file(s) in an underlying - filesystem (e.g. local, S3, GCS). - - Example: - :: - - >>> import holoviews as hv - >>> from kedro.extras.datasets.holoviews import HoloviewsWriter - >>> - >>> curve = hv.Curve(range(10)) - >>> holoviews_writer = HoloviewsWriter("/tmp/holoviews") - >>> - >>> holoviews_writer.save(curve) - - """ - - DEFAULT_SAVE_ARGS = {"fmt": "png"} # type: Dict[str, Any] - - def __init__( # noqa: too-many-arguments - self, - filepath: str, - fs_args: Dict[str, Any] = None, - credentials: Dict[str, Any] = None, - save_args: Dict[str, Any] = None, - version: Version = None, - ) -> None: - """Creates a new instance of ``HoloviewsWriter``. - - Args: - filepath: Filepath in POSIX format to a text file prefixed with a protocol like `s3://`. - If prefix is not provided, `file` protocol (local filesystem) will be used. - The prefix should be any protocol supported by ``fsspec``. - Note: `http(s)` doesn't support versioning. - fs_args: Extra arguments to pass into underlying filesystem class constructor - (e.g. `{"project": "my-project"}` for ``GCSFileSystem``), as well as - to pass to the filesystem's `open` method through nested key `open_args_save`. - Here you can find all available arguments for `open`: - https://filesystem-spec.readthedocs.io/en/latest/api.html#fsspec.spec.AbstractFileSystem.open - All defaults are preserved, except `mode`, which is set to `wb` when saving. - credentials: Credentials required to get access to the underlying filesystem. - E.g. for ``S3FileSystem`` it should look like: - `{'key': '', 'secret': ''}}` - save_args: Extra save args passed to `holoviews.save()`. See - https://holoviews.org/reference_manual/holoviews.util.html#holoviews.util.save - version: If specified, should be an instance of - ``kedro.io.core.Version``. If its ``load`` attribute is - None, the latest version will be loaded. If its ``save`` - attribute is None, save version will be autogenerated. 
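The writer's save path (shown after this docstring) renders the HoloViews object into an in-memory buffer and only then writes the raw bytes through ``fsspec``. A standalone sketch of that pattern (assumes ``holoviews`` with a rendering backend such as matplotlib is installed; the output path is illustrative)::

    import io
    import fsspec
    import holoviews as hv

    hv.extension("matplotlib")          # a backend is needed to render the figure
    curve = hv.Curve(range(10))

    buffer = io.BytesIO()
    hv.save(curve, buffer, fmt="png")   # same call the removed _save() makes

    fs = fsspec.filesystem("file", auto_mkdir=True)
    with fs.open("/tmp/holoviews/curve.png", mode="wb") as f:
        f.write(buffer.getvalue())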
- """ - _credentials = deepcopy(credentials) or {} - _fs_args = deepcopy(fs_args) or {} - _fs_open_args_save = _fs_args.pop("open_args_save", {}) - _fs_open_args_save.setdefault("mode", "wb") - - protocol, path = get_protocol_and_path(filepath, version) - if protocol == "file": - _fs_args.setdefault("auto_mkdir", True) - - self._protocol = protocol - self._fs = fsspec.filesystem(self._protocol, **_credentials, **_fs_args) - - super().__init__( - filepath=PurePosixPath(path), - version=version, - exists_function=self._fs.exists, - glob_function=self._fs.glob, - ) - - self._fs_open_args_save = _fs_open_args_save - - # Handle default save arguments - self._save_args = deepcopy(self.DEFAULT_SAVE_ARGS) - if save_args is not None: - self._save_args.update(save_args) - - def _describe(self) -> Dict[str, Any]: - return { - "filepath": self._filepath, - "protocol": self._protocol, - "save_args": self._save_args, - "version": self._version, - } - - def _load(self) -> NoReturn: - raise DatasetError(f"Loading not supported for '{self.__class__.__name__}'") - - def _save(self, data: HoloViews) -> None: - bytes_buffer = io.BytesIO() - hv.save(data, bytes_buffer, **self._save_args) - - save_path = get_filepath_str(self._get_save_path(), self._protocol) - with self._fs.open(save_path, **self._fs_open_args_save) as fs_file: - fs_file.write(bytes_buffer.getvalue()) - - self._invalidate_cache() - - def _exists(self) -> bool: - load_path = get_filepath_str(self._get_load_path(), self._protocol) - return self._fs.exists(load_path) - - def _release(self) -> None: - super()._release() - self._invalidate_cache() - - def _invalidate_cache(self) -> None: - """Invalidate underlying filesystem caches.""" - filepath = get_filepath_str(self._filepath, self._protocol) - self._fs.invalidate_cache(filepath) diff --git a/kedro/extras/datasets/json/__init__.py b/kedro/extras/datasets/json/__init__.py deleted file mode 100644 index 887f7cd72f..0000000000 --- a/kedro/extras/datasets/json/__init__.py +++ /dev/null @@ -1,8 +0,0 @@ -"""``AbstractDataset`` implementation to load/save data from/to a JSON file.""" - -__all__ = ["JSONDataSet"] - -from contextlib import suppress - -with suppress(ImportError): - from .json_dataset import JSONDataSet diff --git a/kedro/extras/datasets/json/json_dataset.py b/kedro/extras/datasets/json/json_dataset.py deleted file mode 100644 index f5907cc162..0000000000 --- a/kedro/extras/datasets/json/json_dataset.py +++ /dev/null @@ -1,159 +0,0 @@ -"""``JSONDataSet`` loads/saves data from/to a JSON file using an underlying -filesystem (e.g.: local, S3, GCS). It uses native json to handle the JSON file. -""" -import json -from copy import deepcopy -from pathlib import PurePosixPath -from typing import Any, Dict - -import fsspec - -from kedro.io.core import ( - AbstractVersionedDataset, - DatasetError, - Version, - get_filepath_str, - get_protocol_and_path, -) - -# NOTE: kedro.extras.datasets will be removed in Kedro 0.19.0. -# Any contribution to datasets should be made in kedro-datasets -# in kedro-plugins (https://github.com/kedro-org/kedro-plugins) - - -class JSONDataSet(AbstractVersionedDataset[Any, Any]): - """``JSONDataSet`` loads/saves data from/to a JSON file using an underlying - filesystem (e.g.: local, S3, GCS). It uses native json to handle the JSON file. - - Example usage for the - `YAML API `_: - - .. 
code-block:: yaml - - cars: - type: json.JSONDataSet - filepath: gcs://your_bucket/cars.json - fs_args: - project: my-project - credentials: my_gcp_credentials - - Example usage for the - `Python API `_: - :: - - >>> from kedro.extras.datasets.json import JSONDataSet - >>> - >>> data = {'col1': [1, 2], 'col2': [4, 5], 'col3': [5, 6]} - >>> - >>> data_set = JSONDataSet(filepath="test.json") - >>> data_set.save(data) - >>> reloaded = data_set.load() - >>> assert data == reloaded - - """ - - DEFAULT_SAVE_ARGS = {"indent": 2} # type: Dict[str, Any] - - def __init__( # noqa: too-many-arguments - self, - filepath: str, - save_args: Dict[str, Any] = None, - version: Version = None, - credentials: Dict[str, Any] = None, - fs_args: Dict[str, Any] = None, - ) -> None: - """Creates a new instance of ``JSONDataSet`` pointing to a concrete JSON file - on a specific filesystem. - - Args: - filepath: Filepath in POSIX format to a JSON file prefixed with a protocol like `s3://`. - If prefix is not provided, `file` protocol (local filesystem) will be used. - The prefix should be any protocol supported by ``fsspec``. - Note: `http(s)` doesn't support versioning. - save_args: json options for saving JSON files (arguments passed - into ```json.dump``). Here you can find all available arguments: - https://docs.python.org/3/library/json.html - All defaults are preserved, but "default_flow_style", which is set to False. - version: If specified, should be an instance of - ``kedro.io.core.Version``. If its ``load`` attribute is - None, the latest version will be loaded. If its ``save`` - attribute is None, save version will be autogenerated. - credentials: Credentials required to get access to the underlying filesystem. - E.g. for ``GCSFileSystem`` it should look like `{"token": None}`. - fs_args: Extra arguments to pass into underlying filesystem class constructor - (e.g. `{"project": "my-project"}` for ``GCSFileSystem``), as well as - to pass to the filesystem's `open` method through nested keys - `open_args_load` and `open_args_save`. - Here you can find all available arguments for `open`: - https://filesystem-spec.readthedocs.io/en/latest/api.html#fsspec.spec.AbstractFileSystem.open - All defaults are preserved, except `mode`, which is set to `r` when loading - and to `w` when saving. 
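As the NOTE at the top of this file says, the maintained version of this dataset lives in the separate ``kedro-datasets`` package in the kedro-plugins repository. A hedged migration sketch, assuming that package is installed and keeps the same module layout and class name as this removed implementation (the import path below is an assumption; verify against the kedro-datasets version you install)::

    # pip install kedro-datasets
    from kedro_datasets.json import JSONDataSet

    data = {"col1": [1, 2], "col2": [4, 5], "col3": [5, 6]}
    data_set = JSONDataSet(filepath="test.json")
    data_set.save(data)
    assert data_set.load() == data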
- """ - _fs_args = deepcopy(fs_args) or {} - _fs_open_args_load = _fs_args.pop("open_args_load", {}) - _fs_open_args_save = _fs_args.pop("open_args_save", {}) - _credentials = deepcopy(credentials) or {} - - protocol, path = get_protocol_and_path(filepath, version) - - self._protocol = protocol - if protocol == "file": - _fs_args.setdefault("auto_mkdir", True) - self._fs = fsspec.filesystem(self._protocol, **_credentials, **_fs_args) - - super().__init__( - filepath=PurePosixPath(path), - version=version, - exists_function=self._fs.exists, - glob_function=self._fs.glob, - ) - - # Handle default save arguments - self._save_args = deepcopy(self.DEFAULT_SAVE_ARGS) - if save_args is not None: - self._save_args.update(save_args) - - _fs_open_args_save.setdefault("mode", "w") - self._fs_open_args_load = _fs_open_args_load - self._fs_open_args_save = _fs_open_args_save - - def _describe(self) -> Dict[str, Any]: - return { - "filepath": self._filepath, - "protocol": self._protocol, - "save_args": self._save_args, - "version": self._version, - } - - def _load(self) -> Any: - load_path = get_filepath_str(self._get_load_path(), self._protocol) - - with self._fs.open(load_path, **self._fs_open_args_load) as fs_file: - return json.load(fs_file) - - def _save(self, data: Any) -> None: - save_path = get_filepath_str(self._get_save_path(), self._protocol) - - with self._fs.open(save_path, **self._fs_open_args_save) as fs_file: - json.dump(data, fs_file, **self._save_args) - - self._invalidate_cache() - - def _exists(self) -> bool: - try: - load_path = get_filepath_str(self._get_load_path(), self._protocol) - except DatasetError: - return False - - return self._fs.exists(load_path) - - def _release(self) -> None: - super()._release() - self._invalidate_cache() - - def _invalidate_cache(self) -> None: - """Invalidate underlying filesystem caches.""" - filepath = get_filepath_str(self._filepath, self._protocol) - self._fs.invalidate_cache(filepath) diff --git a/kedro/extras/datasets/matplotlib/__init__.py b/kedro/extras/datasets/matplotlib/__init__.py deleted file mode 100644 index eabd8fc517..0000000000 --- a/kedro/extras/datasets/matplotlib/__init__.py +++ /dev/null @@ -1,8 +0,0 @@ -"""``AbstractDataset`` implementation to save matplotlib objects as image files.""" - -__all__ = ["MatplotlibWriter"] - -from contextlib import suppress - -with suppress(ImportError): - from .matplotlib_writer import MatplotlibWriter diff --git a/kedro/extras/datasets/matplotlib/matplotlib_writer.py b/kedro/extras/datasets/matplotlib/matplotlib_writer.py deleted file mode 100644 index 6c29b4d5ba..0000000000 --- a/kedro/extras/datasets/matplotlib/matplotlib_writer.py +++ /dev/null @@ -1,238 +0,0 @@ -"""``MatplotlibWriter`` saves one or more Matplotlib objects as image -files to an underlying filesystem (e.g. local, S3, GCS).""" - -import io -from copy import deepcopy -from pathlib import PurePosixPath -from typing import Any, Dict, List, NoReturn, Union -from warnings import warn - -import fsspec -import matplotlib.pyplot as plt - -from kedro.io.core import ( - AbstractVersionedDataset, - DatasetError, - Version, - get_filepath_str, - get_protocol_and_path, -) - -# NOTE: kedro.extras.datasets will be removed in Kedro 0.19.0. 
-# Any contribution to datasets should be made in kedro-datasets -# in kedro-plugins (https://github.com/kedro-org/kedro-plugins) - - -class MatplotlibWriter( - AbstractVersionedDataset[ - Union[plt.figure, List[plt.figure], Dict[str, plt.figure]], NoReturn - ] -): - """``MatplotlibWriter`` saves one or more Matplotlib objects as - image files to an underlying filesystem (e.g. local, S3, GCS). - - Example usage for the - `YAML API `_: - - .. code-block:: yaml - - output_plot: - type: matplotlib.MatplotlibWriter - filepath: data/08_reporting/output_plot.png - save_args: - format: png - - Example usage for the - `Python API `_: - :: - - >>> import matplotlib.pyplot as plt - >>> from kedro.extras.datasets.matplotlib import MatplotlibWriter - >>> - >>> fig = plt.figure() - >>> plt.plot([1, 2, 3]) - >>> plot_writer = MatplotlibWriter( - >>> filepath="data/08_reporting/output_plot.png" - >>> ) - >>> plt.close() - >>> plot_writer.save(fig) - - Example saving a plot as a PDF file: - :: - - >>> import matplotlib.pyplot as plt - >>> from kedro.extras.datasets.matplotlib import MatplotlibWriter - >>> - >>> fig = plt.figure() - >>> plt.plot([1, 2, 3]) - >>> pdf_plot_writer = MatplotlibWriter( - >>> filepath="data/08_reporting/output_plot.pdf", - >>> save_args={"format": "pdf"}, - >>> ) - >>> plt.close() - >>> pdf_plot_writer.save(fig) - - Example saving multiple plots in a folder, using a dictionary: - :: - - >>> import matplotlib.pyplot as plt - >>> from kedro.extras.datasets.matplotlib import MatplotlibWriter - >>> - >>> plots_dict = {} - >>> for colour in ["blue", "green", "red"]: - >>> plots_dict[f"{colour}.png"] = plt.figure() - >>> plt.plot([1, 2, 3], color=colour) - >>> - >>> plt.close("all") - >>> dict_plot_writer = MatplotlibWriter( - >>> filepath="data/08_reporting/plots" - >>> ) - >>> dict_plot_writer.save(plots_dict) - - Example saving multiple plots in a folder, using a list: - :: - - >>> import matplotlib.pyplot as plt - >>> from kedro.extras.datasets.matplotlib import MatplotlibWriter - >>> - >>> plots_list = [] - >>> for i in range(5): - >>> plots_list.append(plt.figure()) - >>> plt.plot([i, i + 1, i + 2]) - >>> plt.close("all") - >>> list_plot_writer = MatplotlibWriter( - >>> filepath="data/08_reporting/plots" - >>> ) - >>> list_plot_writer.save(plots_list) - - """ - - DEFAULT_SAVE_ARGS = {} # type: Dict[str, Any] - - def __init__( # noqa: too-many-arguments - self, - filepath: str, - fs_args: Dict[str, Any] = None, - credentials: Dict[str, Any] = None, - save_args: Dict[str, Any] = None, - version: Version = None, - overwrite: bool = False, - ) -> None: - """Creates a new instance of ``MatplotlibWriter``. - - Args: - filepath: Filepath in POSIX format to save Matplotlib objects to, prefixed with a - protocol like `s3://`. If prefix is not provided, `file` protocol (local filesystem) - will be used. The prefix should be any protocol supported by ``fsspec``. - fs_args: Extra arguments to pass into underlying filesystem class constructor - (e.g. `{"project": "my-project"}` for ``GCSFileSystem``), as well as - to pass to the filesystem's `open` method through nested key `open_args_save`. - Here you can find all available arguments for `open`: - https://filesystem-spec.readthedocs.io/en/latest/api.html#fsspec.spec.AbstractFileSystem.open - All defaults are preserved, except `mode`, which is set to `wb` when saving. - credentials: Credentials required to get access to the underlying filesystem. - E.g. 
for ``S3FileSystem`` it should look like: - `{'key': '', 'secret': ''}}` - save_args: Save args passed to `plt.savefig`. See - https://matplotlib.org/api/_as_gen/matplotlib.pyplot.savefig.html - version: If specified, should be an instance of - ``kedro.io.core.Version``. If its ``load`` attribute is - None, the latest version will be loaded. If its ``save`` - attribute is None, save version will be autogenerated. - overwrite: If True, any existing image files will be removed. - Only relevant when saving multiple Matplotlib objects at - once. - """ - _credentials = deepcopy(credentials) or {} - _fs_args = deepcopy(fs_args) or {} - _fs_open_args_save = _fs_args.pop("open_args_save", {}) - _fs_open_args_save.setdefault("mode", "wb") - - protocol, path = get_protocol_and_path(filepath, version) - if protocol == "file": - _fs_args.setdefault("auto_mkdir", True) - - self._protocol = protocol - self._fs = fsspec.filesystem(self._protocol, **_credentials, **_fs_args) - - super().__init__( - filepath=PurePosixPath(path), - version=version, - exists_function=self._fs.exists, - glob_function=self._fs.glob, - ) - - self._fs_open_args_save = _fs_open_args_save - - # Handle default save arguments - self._save_args = deepcopy(self.DEFAULT_SAVE_ARGS) - if save_args is not None: - self._save_args.update(save_args) - - if overwrite and version is not None: - warn( - "Setting 'overwrite=True' is ineffective if versioning " - "is enabled, since the versioned path must not already " - "exist; overriding flag with 'overwrite=False' instead." - ) - overwrite = False - self._overwrite = overwrite - - def _describe(self) -> Dict[str, Any]: - return { - "filepath": self._filepath, - "protocol": self._protocol, - "save_args": self._save_args, - "version": self._version, - } - - def _load(self) -> NoReturn: - raise DatasetError(f"Loading not supported for '{self.__class__.__name__}'") - - def _save( - self, data: Union[plt.figure, List[plt.figure], Dict[str, plt.figure]] - ) -> None: - save_path = self._get_save_path() - - if isinstance(data, (list, dict)) and self._overwrite and self._exists(): - self._fs.rm(get_filepath_str(save_path, self._protocol), recursive=True) - - if isinstance(data, list): - for index, plot in enumerate(data): - full_key_path = get_filepath_str( - save_path / f"{index}.png", self._protocol - ) - self._save_to_fs(full_key_path=full_key_path, plot=plot) - elif isinstance(data, dict): - for plot_name, plot in data.items(): - full_key_path = get_filepath_str(save_path / plot_name, self._protocol) - self._save_to_fs(full_key_path=full_key_path, plot=plot) - else: - full_key_path = get_filepath_str(save_path, self._protocol) - self._save_to_fs(full_key_path=full_key_path, plot=data) - - plt.close("all") - - self._invalidate_cache() - - def _save_to_fs(self, full_key_path: str, plot: plt.figure): - bytes_buffer = io.BytesIO() - plot.savefig(bytes_buffer, **self._save_args) - - with self._fs.open(full_key_path, **self._fs_open_args_save) as fs_file: - fs_file.write(bytes_buffer.getvalue()) - - def _exists(self) -> bool: - load_path = get_filepath_str(self._get_load_path(), self._protocol) - return self._fs.exists(load_path) - - def _release(self) -> None: - super()._release() - self._invalidate_cache() - - def _invalidate_cache(self) -> None: - """Invalidate underlying filesystem caches.""" - filepath = get_filepath_str(self._filepath, self._protocol) - self._fs.invalidate_cache(filepath) diff --git a/kedro/extras/datasets/networkx/__init__.py b/kedro/extras/datasets/networkx/__init__.py deleted 
file mode 100644 index ece1b98f9c..0000000000 --- a/kedro/extras/datasets/networkx/__init__.py +++ /dev/null @@ -1,15 +0,0 @@ -"""``AbstractDataset`` implementation to save and load NetworkX graphs in JSON -, GraphML and GML formats using ``NetworkX``.""" - -__all__ = ["GMLDataSet", "GraphMLDataSet", "JSONDataSet"] - -from contextlib import suppress - -with suppress(ImportError): - from .gml_dataset import GMLDataSet - -with suppress(ImportError): - from .graphml_dataset import GraphMLDataSet - -with suppress(ImportError): - from .json_dataset import JSONDataSet diff --git a/kedro/extras/datasets/networkx/gml_dataset.py b/kedro/extras/datasets/networkx/gml_dataset.py deleted file mode 100644 index a56ddbe7ba..0000000000 --- a/kedro/extras/datasets/networkx/gml_dataset.py +++ /dev/null @@ -1,143 +0,0 @@ -"""NetworkX ``GMLDataSet`` loads and saves graphs to a graph modelling language (GML) -file using an underlying filesystem (e.g.: local, S3, GCS). ``NetworkX`` is used to -create GML data. -""" - -from copy import deepcopy -from pathlib import PurePosixPath -from typing import Any, Dict - -import fsspec -import networkx - -from kedro.io.core import ( - AbstractVersionedDataset, - Version, - get_filepath_str, - get_protocol_and_path, -) - -# NOTE: kedro.extras.datasets will be removed in Kedro 0.19.0. -# Any contribution to datasets should be made in kedro-datasets -# in kedro-plugins (https://github.com/kedro-org/kedro-plugins) - - -class GMLDataSet(AbstractVersionedDataset[networkx.Graph, networkx.Graph]): - """``GMLDataSet`` loads and saves graphs to a GML file using an - underlying filesystem (e.g.: local, S3, GCS). ``NetworkX`` is used to - create GML data. - See https://networkx.org/documentation/stable/tutorial.html for details. - - Example: - :: - - >>> from kedro.extras.datasets.networkx import GMLDataSet - >>> import networkx as nx - >>> graph = nx.complete_graph(100) - >>> graph_dataset = GMLDataSet(filepath="test.gml") - >>> graph_dataset.save(graph) - >>> reloaded = graph_dataset.load() - >>> assert nx.is_isomorphic(graph, reloaded) - - """ - - DEFAULT_LOAD_ARGS = {} # type: Dict[str, Any] - DEFAULT_SAVE_ARGS = {} # type: Dict[str, Any] - - def __init__( # noqa: too-many-arguments - self, - filepath: str, - load_args: Dict[str, Any] = None, - save_args: Dict[str, Any] = None, - version: Version = None, - credentials: Dict[str, Any] = None, - fs_args: Dict[str, Any] = None, - ) -> None: - """Creates a new instance of ``GMLDataSet``. - - Args: - filepath: Filepath in POSIX format to the NetworkX GML file. - load_args: Arguments passed on to ``networkx.read_gml``. - See the details in - https://networkx.org/documentation/stable/reference/readwrite/generated/networkx.readwrite.gml.read_gml.html - save_args: Arguments passed on to ``networkx.write_gml``. - See the details in - https://networkx.org/documentation/stable/reference/readwrite/generated/networkx.readwrite.gml.write_gml.html - version: If specified, should be an instance of - ``kedro.io.core.Version``. If its ``load`` attribute is - None, the latest version will be loaded. If its ``save`` - attribute is None, save version will be autogenerated. - credentials: Credentials required to get access to the underlying filesystem. - E.g. for ``GCSFileSystem`` it should look like `{"token": None}`. - fs_args: Extra arguments to pass into underlying filesystem class constructor - (e.g. 
`{"project": "my-project"}` for ``GCSFileSystem``), as well as - to pass to the filesystem's `open` method through nested keys - `open_args_load` and `open_args_save`. - Here you can find all available arguments for `open`: - https://filesystem-spec.readthedocs.io/en/latest/api.html#fsspec.spec.AbstractFileSystem.open - All defaults are preserved, except `mode`, which is set to `r` when loading - and to `w` when saving. - """ - _fs_args = deepcopy(fs_args) or {} - _fs_open_args_load = _fs_args.pop("open_args_load", {}) - _fs_open_args_save = _fs_args.pop("open_args_save", {}) - _credentials = deepcopy(credentials) or {} - - protocol, path = get_protocol_and_path(filepath, version) - if protocol == "file": - _fs_args.setdefault("auto_mkdir", True) - - self._protocol = protocol - self._fs = fsspec.filesystem(self._protocol, **_credentials, **_fs_args) - - super().__init__( - filepath=PurePosixPath(path), - version=version, - exists_function=self._fs.exists, - glob_function=self._fs.glob, - ) - - # Handle default load and save arguments - self._load_args = deepcopy(self.DEFAULT_LOAD_ARGS) - if load_args is not None: - self._load_args.update(load_args) - self._save_args = deepcopy(self.DEFAULT_SAVE_ARGS) - if save_args is not None: - self._save_args.update(save_args) - _fs_open_args_load.setdefault("mode", "rb") - _fs_open_args_save.setdefault("mode", "wb") - self._fs_open_args_load = _fs_open_args_load - self._fs_open_args_save = _fs_open_args_save - - def _load(self) -> networkx.Graph: - load_path = get_filepath_str(self._get_load_path(), self._protocol) - with self._fs.open(load_path, **self._fs_open_args_load) as fs_file: - data = networkx.read_gml(fs_file, **self._load_args) - return data - - def _save(self, data: networkx.Graph) -> None: - save_path = get_filepath_str(self._get_save_path(), self._protocol) - with self._fs.open(save_path, **self._fs_open_args_save) as fs_file: - networkx.write_gml(data, fs_file, **self._save_args) - self._invalidate_cache() - - def _exists(self) -> bool: - load_path = get_filepath_str(self._get_load_path(), self._protocol) - return self._fs.exists(load_path) - - def _describe(self) -> Dict[str, Any]: - return { - "filepath": self._filepath, - "protocol": self._load_args, - "save_args": self._save_args, - "version": self._version, - } - - def _release(self) -> None: - super()._release() - self._invalidate_cache() - - def _invalidate_cache(self) -> None: - """Invalidate underlying filesystem caches.""" - filepath = get_filepath_str(self._filepath, self._protocol) - self._fs.invalidate_cache(filepath) diff --git a/kedro/extras/datasets/networkx/graphml_dataset.py b/kedro/extras/datasets/networkx/graphml_dataset.py deleted file mode 100644 index 368459958f..0000000000 --- a/kedro/extras/datasets/networkx/graphml_dataset.py +++ /dev/null @@ -1,141 +0,0 @@ -"""NetworkX ``GraphMLDataSet`` loads and saves graphs to a GraphML file using an underlying -filesystem (e.g.: local, S3, GCS). ``NetworkX`` is used to create GraphML data. -""" - -from copy import deepcopy -from pathlib import PurePosixPath -from typing import Any, Dict - -import fsspec -import networkx - -from kedro.io.core import ( - AbstractVersionedDataset, - Version, - get_filepath_str, - get_protocol_and_path, -) - -# NOTE: kedro.extras.datasets will be removed in Kedro 0.19.0. 
-# Any contribution to datasets should be made in kedro-datasets -# in kedro-plugins (https://github.com/kedro-org/kedro-plugins) - - -class GraphMLDataSet(AbstractVersionedDataset[networkx.Graph, networkx.Graph]): - """``GraphMLDataSet`` loads and saves graphs to a GraphML file using an - underlying filesystem (e.g.: local, S3, GCS). ``NetworkX`` is used to - create GraphML data. - See https://networkx.org/documentation/stable/tutorial.html for details. - - Example: - :: - - >>> from kedro.extras.datasets.networkx import GraphMLDataSet - >>> import networkx as nx - >>> graph = nx.complete_graph(100) - >>> graph_dataset = GraphMLDataSet(filepath="test.graphml") - >>> graph_dataset.save(graph) - >>> reloaded = graph_dataset.load() - >>> assert nx.is_isomorphic(graph, reloaded) - - """ - - DEFAULT_LOAD_ARGS = {} # type: Dict[str, Any] - DEFAULT_SAVE_ARGS = {} # type: Dict[str, Any] - - def __init__( # noqa: too-many-arguments - self, - filepath: str, - load_args: Dict[str, Any] = None, - save_args: Dict[str, Any] = None, - version: Version = None, - credentials: Dict[str, Any] = None, - fs_args: Dict[str, Any] = None, - ) -> None: - """Creates a new instance of ``GraphMLDataSet``. - - Args: - filepath: Filepath in POSIX format to the NetworkX GraphML file. - load_args: Arguments passed on to ``networkx.read_graphml``. - See the details in - https://networkx.org/documentation/stable/reference/readwrite/generated/networkx.readwrite.graphml.read_graphml.html - save_args: Arguments passed on to ``networkx.write_graphml``. - See the details in - https://networkx.org/documentation/stable/reference/readwrite/generated/networkx.readwrite.graphml.write_graphml.html - version: If specified, should be an instance of - ``kedro.io.core.Version``. If its ``load`` attribute is - None, the latest version will be loaded. If its ``save`` - attribute is None, save version will be autogenerated. - credentials: Credentials required to get access to the underlying filesystem. - E.g. for ``GCSFileSystem`` it should look like `{"token": None}`. - fs_args: Extra arguments to pass into underlying filesystem class constructor - (e.g. `{"project": "my-project"}` for ``GCSFileSystem``), as well as - to pass to the filesystem's `open` method through nested keys - `open_args_load` and `open_args_save`. - Here you can find all available arguments for `open`: - https://filesystem-spec.readthedocs.io/en/latest/api.html#fsspec.spec.AbstractFileSystem.open - All defaults are preserved, except `mode`, which is set to `r` when loading - and to `w` when saving. 
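The `_load`/`_save` methods that follow are thin wrappers around ``networkx.read_graphml``/``networkx.write_graphml`` applied to an ``fsspec`` file handle. A minimal round-trip sketch of the same calls on the local filesystem (path is illustrative)::

    import fsspec
    import networkx as nx

    graph = nx.complete_graph(10)
    fs = fsspec.filesystem("file", auto_mkdir=True)

    with fs.open("/tmp/test.graphml", mode="wb") as f:   # binary mode, as in the removed defaults
        nx.write_graphml(graph, f)

    with fs.open("/tmp/test.graphml", mode="rb") as f:
        reloaded = nx.read_graphml(f)

    assert nx.is_isomorphic(graph, reloaded)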
- """ - _fs_args = deepcopy(fs_args) or {} - _fs_open_args_load = _fs_args.pop("open_args_load", {}) - _fs_open_args_save = _fs_args.pop("open_args_save", {}) - _credentials = deepcopy(credentials) or {} - - protocol, path = get_protocol_and_path(filepath, version) - if protocol == "file": - _fs_args.setdefault("auto_mkdir", True) - - self._protocol = protocol - self._fs = fsspec.filesystem(self._protocol, **_credentials, **_fs_args) - - super().__init__( - filepath=PurePosixPath(path), - version=version, - exists_function=self._fs.exists, - glob_function=self._fs.glob, - ) - - # Handle default load and save arguments - self._load_args = deepcopy(self.DEFAULT_LOAD_ARGS) - if load_args is not None: - self._load_args.update(load_args) - self._save_args = deepcopy(self.DEFAULT_SAVE_ARGS) - if save_args is not None: - self._save_args.update(save_args) - _fs_open_args_load.setdefault("mode", "rb") - _fs_open_args_save.setdefault("mode", "wb") - self._fs_open_args_load = _fs_open_args_load - self._fs_open_args_save = _fs_open_args_save - - def _load(self) -> networkx.Graph: - load_path = get_filepath_str(self._get_load_path(), self._protocol) - with self._fs.open(load_path, **self._fs_open_args_load) as fs_file: - return networkx.read_graphml(fs_file, **self._load_args) - - def _save(self, data: networkx.Graph) -> None: - save_path = get_filepath_str(self._get_save_path(), self._protocol) - with self._fs.open(save_path, **self._fs_open_args_save) as fs_file: - networkx.write_graphml(data, fs_file, **self._save_args) - self._invalidate_cache() - - def _exists(self) -> bool: - load_path = get_filepath_str(self._get_load_path(), self._protocol) - return self._fs.exists(load_path) - - def _describe(self) -> Dict[str, Any]: - return { - "filepath": self._filepath, - "protocol": self._load_args, - "save_args": self._save_args, - "version": self._version, - } - - def _release(self) -> None: - super()._release() - self._invalidate_cache() - - def _invalidate_cache(self) -> None: - """Invalidate underlying filesystem caches.""" - filepath = get_filepath_str(self._filepath, self._protocol) - self._fs.invalidate_cache(filepath) diff --git a/kedro/extras/datasets/networkx/json_dataset.py b/kedro/extras/datasets/networkx/json_dataset.py deleted file mode 100644 index 60db837a91..0000000000 --- a/kedro/extras/datasets/networkx/json_dataset.py +++ /dev/null @@ -1,148 +0,0 @@ -"""``JSONDataSet`` loads and saves graphs to a JSON file using an underlying -filesystem (e.g.: local, S3, GCS). ``NetworkX`` is used to create JSON data. -""" - -import json -from copy import deepcopy -from pathlib import PurePosixPath -from typing import Any, Dict - -import fsspec -import networkx - -from kedro.io.core import ( - AbstractVersionedDataset, - Version, - get_filepath_str, - get_protocol_and_path, -) - -# NOTE: kedro.extras.datasets will be removed in Kedro 0.19.0. -# Any contribution to datasets should be made in kedro-datasets -# in kedro-plugins (https://github.com/kedro-org/kedro-plugins) - - -class JSONDataSet(AbstractVersionedDataset[networkx.Graph, networkx.Graph]): - """NetworkX ``JSONDataSet`` loads and saves graphs to a JSON file using an - underlying filesystem (e.g.: local, S3, GCS). ``NetworkX`` is used to - create JSON data. - See https://networkx.org/documentation/stable/tutorial.html for details. 
- - Example: - :: - - >>> from kedro.extras.datasets.networkx import JSONDataSet - >>> import networkx as nx - >>> graph = nx.complete_graph(100) - >>> graph_dataset = JSONDataSet(filepath="test.json") - >>> graph_dataset.save(graph) - >>> reloaded = graph_dataset.load() - >>> assert nx.is_isomorphic(graph, reloaded) - - """ - - DEFAULT_LOAD_ARGS = {} # type: Dict[str, Any] - DEFAULT_SAVE_ARGS = {} # type: Dict[str, Any] - - def __init__( # noqa: too-many-arguments - self, - filepath: str, - load_args: Dict[str, Any] = None, - save_args: Dict[str, Any] = None, - version: Version = None, - credentials: Dict[str, Any] = None, - fs_args: Dict[str, Any] = None, - ) -> None: - """Creates a new instance of ``JSONDataSet``. - - Args: - filepath: Filepath in POSIX format to the NetworkX graph JSON file. - load_args: Arguments passed on to ``networkx.node_link_graph``. - See the details in - https://networkx.org/documentation/networkx-1.9.1/reference/generated/networkx.readwrite.json_graph.node_link_graph.html - save_args: Arguments passed on to ``networkx.node_link_data``. - See the details in - https://networkx.org/documentation/networkx-1.9.1/reference/generated/networkx.readwrite.json_graph.node_link_data.html - version: If specified, should be an instance of - ``kedro.io.core.Version``. If its ``load`` attribute is - None, the latest version will be loaded. If its ``save`` - attribute is None, save version will be autogenerated. - credentials: Credentials required to get access to the underlying filesystem. - E.g. for ``GCSFileSystem`` it should look like `{"token": None}`. - fs_args: Extra arguments to pass into underlying filesystem class constructor - (e.g. `{"project": "my-project"}` for ``GCSFileSystem``), as well as - to pass to the filesystem's `open` method through nested keys - `open_args_load` and `open_args_save`. - Here you can find all available arguments for `open`: - https://filesystem-spec.readthedocs.io/en/latest/api.html#fsspec.spec.AbstractFileSystem.open - All defaults are preserved, except `mode`, which is set to `r` when loading - and to `w` when saving. 
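The implementation that follows serialises graphs through NetworkX's node-link format before handing the resulting dict to the standard ``json`` module. A standalone sketch of that conversion (no filesystem involved)::

    import json
    import networkx as nx

    graph = nx.complete_graph(5)
    payload = nx.node_link_data(graph)    # plain dict: nodes, links, graph attributes
    text = json.dumps(payload)

    restored = nx.node_link_graph(json.loads(text))
    assert nx.is_isomorphic(graph, restored)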
- """ - _fs_args = deepcopy(fs_args) or {} - _fs_open_args_load = _fs_args.pop("open_args_load", {}) - _fs_open_args_save = _fs_args.pop("open_args_save", {}) - _credentials = deepcopy(credentials) or {} - - protocol, path = get_protocol_and_path(filepath, version) - if protocol == "file": - _fs_args.setdefault("auto_mkdir", True) - - self._protocol = protocol - self._fs = fsspec.filesystem(self._protocol, **_credentials, **_fs_args) - - super().__init__( - filepath=PurePosixPath(path), - version=version, - exists_function=self._fs.exists, - glob_function=self._fs.glob, - ) - - # Handle default load and save arguments - self._load_args = deepcopy(self.DEFAULT_LOAD_ARGS) - if load_args is not None: - self._load_args.update(load_args) - self._save_args = deepcopy(self.DEFAULT_SAVE_ARGS) - if save_args is not None: - self._save_args.update(save_args) - - _fs_open_args_save.setdefault("mode", "w") - self._fs_open_args_load = _fs_open_args_load - self._fs_open_args_save = _fs_open_args_save - - def _load(self) -> networkx.Graph: - load_path = get_filepath_str(self._get_load_path(), self._protocol) - with self._fs.open(load_path, **self._fs_open_args_load) as fs_file: - json_payload = json.load(fs_file) - - return networkx.node_link_graph(json_payload, **self._load_args) - - def _save(self, data: networkx.Graph) -> None: - save_path = get_filepath_str(self._get_save_path(), self._protocol) - - json_graph = networkx.node_link_data(data, **self._save_args) - with self._fs.open(save_path, **self._fs_open_args_save) as fs_file: - json.dump(json_graph, fs_file) - - self._invalidate_cache() - - def _exists(self) -> bool: - load_path = get_filepath_str(self._get_load_path(), self._protocol) - - return self._fs.exists(load_path) - - def _describe(self) -> Dict[str, Any]: - return { - "filepath": self._filepath, - "protocol": self._load_args, - "save_args": self._save_args, - "version": self._version, - } - - def _release(self) -> None: - super()._release() - self._invalidate_cache() - - def _invalidate_cache(self) -> None: - """Invalidate underlying filesystem caches.""" - filepath = get_filepath_str(self._filepath, self._protocol) - self._fs.invalidate_cache(filepath) diff --git a/kedro/extras/datasets/pandas/__init__.py b/kedro/extras/datasets/pandas/__init__.py deleted file mode 100644 index 2a8ba76371..0000000000 --- a/kedro/extras/datasets/pandas/__init__.py +++ /dev/null @@ -1,39 +0,0 @@ -"""``AbstractDataset`` implementations that produce pandas DataFrames.""" - -__all__ = [ - "CSVDataSet", - "ExcelDataSet", - "FeatherDataSet", - "GBQTableDataSet", - "GBQQueryDataSet", - "HDFDataSet", - "JSONDataSet", - "ParquetDataSet", - "SQLQueryDataSet", - "SQLTableDataSet", - "XMLDataSet", - "GenericDataSet", -] - -from contextlib import suppress - -with suppress(ImportError): - from .csv_dataset import CSVDataSet -with suppress(ImportError): - from .excel_dataset import ExcelDataSet -with suppress(ImportError): - from .feather_dataset import FeatherDataSet -with suppress(ImportError): - from .gbq_dataset import GBQQueryDataSet, GBQTableDataSet -with suppress(ImportError): - from .hdf_dataset import HDFDataSet -with suppress(ImportError): - from .json_dataset import JSONDataSet -with suppress(ImportError): - from .parquet_dataset import ParquetDataSet -with suppress(ImportError): - from .sql_dataset import SQLQueryDataSet, SQLTableDataSet -with suppress(ImportError): - from .xml_dataset import XMLDataSet -with suppress(ImportError): - from .generic_dataset import GenericDataSet diff --git 
a/kedro/extras/datasets/pandas/csv_dataset.py b/kedro/extras/datasets/pandas/csv_dataset.py deleted file mode 100644 index 26816da5d4..0000000000 --- a/kedro/extras/datasets/pandas/csv_dataset.py +++ /dev/null @@ -1,193 +0,0 @@ -"""``CSVDataSet`` loads/saves data from/to a CSV file using an underlying -filesystem (e.g.: local, S3, GCS). It uses pandas to handle the CSV file. -""" -import logging -from copy import deepcopy -from io import BytesIO -from pathlib import PurePosixPath -from typing import Any, Dict - -import fsspec -import pandas as pd - -from kedro.io.core import ( - PROTOCOL_DELIMITER, - AbstractVersionedDataset, - DatasetError, - Version, - get_filepath_str, - get_protocol_and_path, -) - -logger = logging.getLogger(__name__) - -# NOTE: kedro.extras.datasets will be removed in Kedro 0.19.0. -# Any contribution to datasets should be made in kedro-datasets -# in kedro-plugins (https://github.com/kedro-org/kedro-plugins) - - -class CSVDataSet(AbstractVersionedDataset[pd.DataFrame, pd.DataFrame]): - """``CSVDataSet`` loads/saves data from/to a CSV file using an underlying - filesystem (e.g.: local, S3, GCS). It uses pandas to handle the CSV file. - - Example usage for the - `YAML API `_: - - .. code-block:: yaml - - cars: - type: pandas.CSVDataSet - filepath: data/01_raw/company/cars.csv - load_args: - sep: "," - na_values: ["#NA", NA] - save_args: - index: False - date_format: "%Y-%m-%d %H:%M" - decimal: . - - motorbikes: - type: pandas.CSVDataSet - filepath: s3://your_bucket/data/02_intermediate/company/motorbikes.csv - credentials: dev_s3 - - Example usage for the - `Python API `_: - :: - - >>> from kedro.extras.datasets.pandas import CSVDataSet - >>> import pandas as pd - >>> - >>> data = pd.DataFrame({'col1': [1, 2], 'col2': [4, 5], - >>> 'col3': [5, 6]}) - >>> - >>> data_set = CSVDataSet(filepath="test.csv") - >>> data_set.save(data) - >>> reloaded = data_set.load() - >>> assert data.equals(reloaded) - - """ - - DEFAULT_LOAD_ARGS = {} # type: Dict[str, Any] - DEFAULT_SAVE_ARGS = {"index": False} # type: Dict[str, Any] - - def __init__( # noqa: too-many-arguments - self, - filepath: str, - load_args: Dict[str, Any] = None, - save_args: Dict[str, Any] = None, - version: Version = None, - credentials: Dict[str, Any] = None, - fs_args: Dict[str, Any] = None, - ) -> None: - """Creates a new instance of ``CSVDataSet`` pointing to a concrete CSV file - on a specific filesystem. - - Args: - filepath: Filepath in POSIX format to a CSV file prefixed with a protocol like `s3://`. - If prefix is not provided, `file` protocol (local filesystem) will be used. - The prefix should be any protocol supported by ``fsspec``. - Note: `http(s)` doesn't support versioning. - load_args: Pandas options for loading CSV files. - Here you can find all available arguments: - https://pandas.pydata.org/pandas-docs/stable/generated/pandas.read_csv.html - All defaults are preserved. - save_args: Pandas options for saving CSV files. - Here you can find all available arguments: - https://pandas.pydata.org/pandas-docs/stable/generated/pandas.DataFrame.to_csv.html - All defaults are preserved, but "index", which is set to False. - version: If specified, should be an instance of - ``kedro.io.core.Version``. If its ``load`` attribute is - None, the latest version will be loaded. If its ``save`` - attribute is None, save version will be autogenerated. - credentials: Credentials required to get access to the underlying filesystem. - E.g. for ``GCSFileSystem`` it should look like `{"token": None}`. 
- fs_args: Extra arguments to pass into underlying filesystem class constructor - (e.g. `{"project": "my-project"}` for ``GCSFileSystem``). - """ - _fs_args = deepcopy(fs_args) or {} - _credentials = deepcopy(credentials) or {} - - protocol, path = get_protocol_and_path(filepath, version) - if protocol == "file": - _fs_args.setdefault("auto_mkdir", True) - - self._protocol = protocol - self._storage_options = {**_credentials, **_fs_args} - self._fs = fsspec.filesystem(self._protocol, **self._storage_options) - - super().__init__( - filepath=PurePosixPath(path), - version=version, - exists_function=self._fs.exists, - glob_function=self._fs.glob, - ) - - # Handle default load and save arguments - self._load_args = deepcopy(self.DEFAULT_LOAD_ARGS) - if load_args is not None: - self._load_args.update(load_args) - self._save_args = deepcopy(self.DEFAULT_SAVE_ARGS) - if save_args is not None: - self._save_args.update(save_args) - - if "storage_options" in self._save_args or "storage_options" in self._load_args: - logger.warning( - "Dropping 'storage_options' for %s, " - "please specify them under 'fs_args' or 'credentials'.", - self._filepath, - ) - self._save_args.pop("storage_options", None) - self._load_args.pop("storage_options", None) - - def _describe(self) -> Dict[str, Any]: - return { - "filepath": self._filepath, - "protocol": self._load_args, - "save_args": self._save_args, - "version": self._version, - } - - def _load(self) -> pd.DataFrame: - load_path = str(self._get_load_path()) - if self._protocol == "file": - # file:// protocol seems to misbehave on Windows - # (), - # so we don't join that back to the filepath; - # storage_options also don't work with local paths - return pd.read_csv(load_path, **self._load_args) - - load_path = f"{self._protocol}{PROTOCOL_DELIMITER}{load_path}" - return pd.read_csv( - load_path, storage_options=self._storage_options, **self._load_args - ) - - def _save(self, data: pd.DataFrame) -> None: - save_path = get_filepath_str(self._get_save_path(), self._protocol) - - buf = BytesIO() - data.to_csv(path_or_buf=buf, **self._save_args) - - with self._fs.open(save_path, mode="wb") as fs_file: - fs_file.write(buf.getvalue()) - - self._invalidate_cache() - - def _exists(self) -> bool: - try: - load_path = get_filepath_str(self._get_load_path(), self._protocol) - except DatasetError: - return False - - return self._fs.exists(load_path) - - def _release(self) -> None: - super()._release() - self._invalidate_cache() - - def _invalidate_cache(self) -> None: - """Invalidate underlying filesystem caches.""" - filepath = get_filepath_str(self._filepath, self._protocol) - self._fs.invalidate_cache(filepath) diff --git a/kedro/extras/datasets/pandas/excel_dataset.py b/kedro/extras/datasets/pandas/excel_dataset.py deleted file mode 100644 index ebf5015b72..0000000000 --- a/kedro/extras/datasets/pandas/excel_dataset.py +++ /dev/null @@ -1,263 +0,0 @@ -"""``ExcelDataSet`` loads/saves data from/to a Excel file using an underlying -filesystem (e.g.: local, S3, GCS). It uses pandas to handle the Excel file. -""" -import logging -from copy import deepcopy -from io import BytesIO -from pathlib import PurePosixPath -from typing import Any, Dict, Union - -import fsspec -import pandas as pd - -from kedro.io.core import ( - PROTOCOL_DELIMITER, - AbstractVersionedDataset, - DatasetError, - Version, - get_filepath_str, - get_protocol_and_path, -) - -logger = logging.getLogger(__name__) - -# NOTE: kedro.extras.datasets will be removed in Kedro 0.19.0. 
-# Any contribution to datasets should be made in kedro-datasets -# in kedro-plugins (https://github.com/kedro-org/kedro-plugins) - - -class ExcelDataSet( - AbstractVersionedDataset[ - Union[pd.DataFrame, Dict[str, pd.DataFrame]], - Union[pd.DataFrame, Dict[str, pd.DataFrame]], - ] -): - """``ExcelDataSet`` loads/saves data from/to a Excel file using an underlying - filesystem (e.g.: local, S3, GCS). It uses pandas to handle the Excel file. - - Example usage for the - `YAML API `_: - - .. code-block:: yaml - - rockets: - type: pandas.ExcelDataSet - filepath: gcs://your_bucket/rockets.xlsx - fs_args: - project: my-project - credentials: my_gcp_credentials - save_args: - sheet_name: Sheet1 - load_args: - sheet_name: Sheet1 - - shuttles: - type: pandas.ExcelDataSet - filepath: data/01_raw/shuttles.xlsx - - Example usage for the - `Python API `_: - :: - - >>> from kedro.extras.datasets.pandas import ExcelDataSet - >>> import pandas as pd - >>> - >>> data = pd.DataFrame({'col1': [1, 2], 'col2': [4, 5], - >>> 'col3': [5, 6]}) - >>> - >>> data_set = ExcelDataSet(filepath="test.xlsx") - >>> data_set.save(data) - >>> reloaded = data_set.load() - >>> assert data.equals(reloaded) - - To save a multi-sheet Excel file, no special ``save_args`` are required. - Instead, return a dictionary of ``Dict[str, pd.DataFrame]`` where the string - keys are your sheet names. - - Example usage for the - `YAML API `_ - for a multi-sheet Excel file: - - .. code-block:: yaml - - trains: - type: pandas.ExcelDataSet - filepath: data/02_intermediate/company/trains.xlsx - load_args: - sheet_name: [Sheet1, Sheet2, Sheet3] - - Example usage for the - `Python API `_ - for a multi-sheet Excel file: - :: - - >>> from kedro.extras.datasets.pandas import ExcelDataSet - >>> import pandas as pd - >>> - >>> dataframe = pd.DataFrame({'col1': [1, 2], 'col2': [4, 5], - >>> 'col3': [5, 6]}) - >>> another_dataframe = pd.DataFrame({"x": [10, 20], "y": ["hello", "world"]}) - >>> multiframe = {"Sheet1": dataframe, "Sheet2": another_dataframe} - >>> data_set = ExcelDataSet(filepath="test.xlsx", load_args = {"sheet_name": None}) - >>> data_set.save(multiframe) - >>> reloaded = data_set.load() - >>> assert multiframe["Sheet1"].equals(reloaded["Sheet1"]) - >>> assert multiframe["Sheet2"].equals(reloaded["Sheet2"]) - - """ - - DEFAULT_LOAD_ARGS = {"engine": "openpyxl"} - DEFAULT_SAVE_ARGS = {"index": False} - - def __init__( # noqa: too-many-arguments - self, - filepath: str, - engine: str = "openpyxl", - load_args: Dict[str, Any] = None, - save_args: Dict[str, Any] = None, - version: Version = None, - credentials: Dict[str, Any] = None, - fs_args: Dict[str, Any] = None, - ) -> None: - """Creates a new instance of ``ExcelDataSet`` pointing to a concrete Excel file - on a specific filesystem. - - Args: - filepath: Filepath in POSIX format to a Excel file prefixed with a protocol like - `s3://`. If prefix is not provided, `file` protocol (local filesystem) will be used. - The prefix should be any protocol supported by ``fsspec``. - Note: `http(s)` doesn't support versioning. - engine: The engine used to write to Excel files. The default - engine is 'openpyxl'. - load_args: Pandas options for loading Excel files. - Here you can find all available arguments: - https://pandas.pydata.org/pandas-docs/stable/generated/pandas.read_excel.html - All defaults are preserved, but "engine", which is set to "openpyxl". - Supports multi-sheet Excel files (include `sheet_name = None` in `load_args`). - save_args: Pandas options for saving Excel files. 
- Here you can find all available arguments: - https://pandas.pydata.org/pandas-docs/stable/generated/pandas.DataFrame.to_excel.html - All defaults are preserved, but "index", which is set to False. - If you would like to specify options for the `ExcelWriter`, - you can include them under the "writer" key. Here you can - find all available arguments: - https://pandas.pydata.org/pandas-docs/stable/reference/api/pandas.ExcelWriter.html - version: If specified, should be an instance of - ``kedro.io.core.Version``. If its ``load`` attribute is - None, the latest version will be loaded. If its ``save`` - attribute is None, save version will be autogenerated. - credentials: Credentials required to get access to the underlying filesystem. - E.g. for ``GCSFileSystem`` it should look like `{"token": None}`. - fs_args: Extra arguments to pass into underlying filesystem class constructor - (e.g. `{"project": "my-project"}` for ``GCSFileSystem``). - - Raises: - DatasetError: If versioning is enabled while in append mode. - """ - _fs_args = deepcopy(fs_args) or {} - _credentials = deepcopy(credentials) or {} - - protocol, path = get_protocol_and_path(filepath, version) - if protocol == "file": - _fs_args.setdefault("auto_mkdir", True) - - self._protocol = protocol - self._storage_options = {**_credentials, **_fs_args} - self._fs = fsspec.filesystem(self._protocol, **self._storage_options) - - super().__init__( - filepath=PurePosixPath(path), - version=version, - exists_function=self._fs.exists, - glob_function=self._fs.glob, - ) - - # Handle default load arguments - self._load_args = deepcopy(self.DEFAULT_LOAD_ARGS) - if load_args is not None: - self._load_args.update(load_args) - - # Handle default save arguments - self._save_args = deepcopy(self.DEFAULT_SAVE_ARGS) - if save_args is not None: - self._save_args.update(save_args) - self._writer_args = self._save_args.pop("writer", {}) # type: ignore - self._writer_args.setdefault("engine", engine or "openpyxl") # type: ignore - - if version and self._writer_args.get("mode") == "a": # type: ignore - raise DatasetError( - "'ExcelDataSet' doesn't support versioning in append mode." 
- ) - - if "storage_options" in self._save_args or "storage_options" in self._load_args: - logger.warning( - "Dropping 'storage_options' for %s, " - "please specify them under 'fs_args' or 'credentials'.", - self._filepath, - ) - self._save_args.pop("storage_options", None) - self._load_args.pop("storage_options", None) - - def _describe(self) -> Dict[str, Any]: - return { - "filepath": self._filepath, - "protocol": self._protocol, - "load_args": self._load_args, - "save_args": self._save_args, - "writer_args": self._writer_args, - "version": self._version, - } - - def _load(self) -> Union[pd.DataFrame, Dict[str, pd.DataFrame]]: - load_path = str(self._get_load_path()) - if self._protocol == "file": - # file:// protocol seems to misbehave on Windows - # (), - # so we don't join that back to the filepath; - # storage_options also don't work with local paths - return pd.read_excel(load_path, **self._load_args) - - load_path = f"{self._protocol}{PROTOCOL_DELIMITER}{load_path}" - return pd.read_excel( - load_path, storage_options=self._storage_options, **self._load_args - ) - - def _save(self, data: Union[pd.DataFrame, Dict[str, pd.DataFrame]]) -> None: - output = BytesIO() - save_path = get_filepath_str(self._get_save_path(), self._protocol) - - # pylint: disable=abstract-class-instantiated - with pd.ExcelWriter(output, **self._writer_args) as writer: - if isinstance(data, dict): - for sheet_name, sheet_data in data.items(): - sheet_data.to_excel( - writer, sheet_name=sheet_name, **self._save_args - ) - else: - data.to_excel(writer, **self._save_args) - - with self._fs.open(save_path, mode="wb") as fs_file: - fs_file.write(output.getvalue()) - - self._invalidate_cache() - - def _exists(self) -> bool: - try: - load_path = get_filepath_str(self._get_load_path(), self._protocol) - except DatasetError: - return False - - return self._fs.exists(load_path) - - def _release(self) -> None: - super()._release() - self._invalidate_cache() - - def _invalidate_cache(self) -> None: - """Invalidate underlying filesystem caches.""" - filepath = get_filepath_str(self._filepath, self._protocol) - self._fs.invalidate_cache(filepath) diff --git a/kedro/extras/datasets/pandas/feather_dataset.py b/kedro/extras/datasets/pandas/feather_dataset.py deleted file mode 100644 index 445cd9758a..0000000000 --- a/kedro/extras/datasets/pandas/feather_dataset.py +++ /dev/null @@ -1,189 +0,0 @@ -"""``FeatherDataSet`` is a data set used to load and save data to feather files -using an underlying filesystem (e.g.: local, S3, GCS). The underlying functionality -is supported by pandas, so it supports all operations the pandas supports. -""" -import logging -from copy import deepcopy -from io import BytesIO -from pathlib import PurePosixPath -from typing import Any, Dict - -import fsspec -import pandas as pd - -from kedro.io.core import ( - PROTOCOL_DELIMITER, - AbstractVersionedDataset, - Version, - get_filepath_str, - get_protocol_and_path, -) - -logger = logging.getLogger(__name__) - -# NOTE: kedro.extras.datasets will be removed in Kedro 0.19.0. -# Any contribution to datasets should be made in kedro-datasets -# in kedro-plugins (https://github.com/kedro-org/kedro-plugins) - - -class FeatherDataSet(AbstractVersionedDataset[pd.DataFrame, pd.DataFrame]): - """``FeatherDataSet`` loads and saves data to a feather file using an - underlying filesystem (e.g.: local, S3, GCS). The underlying functionality - is supported by pandas, so it supports all allowed pandas options - for loading and saving csv files. 
- - Example usage for the - `YAML API `_: - - .. code-block:: yaml - - cars: - type: pandas.FeatherDataSet - filepath: data/01_raw/company/cars.feather - load_args: - columns: ['col1', 'col2', 'col3'] - use_threads: True - - motorbikes: - type: pandas.FeatherDataSet - filepath: s3://your_bucket/data/02_intermediate/company/motorbikes.feather - credentials: dev_s3 - - Example usage for the - `Python API `_: - :: - - >>> from kedro.extras.datasets.pandas import FeatherDataSet - >>> import pandas as pd - >>> - >>> data = pd.DataFrame({'col1': [1, 2], 'col2': [4, 5], - >>> 'col3': [5, 6]}) - >>> - >>> data_set = FeatherDataSet(filepath="test.feather") - >>> - >>> data_set.save(data) - >>> reloaded = data_set.load() - >>> - >>> assert data.equals(reloaded) - - """ - - DEFAULT_LOAD_ARGS = {} # type: Dict[str, Any] - DEFAULT_SAVE_ARGS = {} # type: Dict[str, Any] - - def __init__( # noqa: too-many-arguments - self, - filepath: str, - load_args: Dict[str, Any] = None, - save_args: Dict[str, Any] = None, - version: Version = None, - credentials: Dict[str, Any] = None, - fs_args: Dict[str, Any] = None, - ) -> None: - """Creates a new instance of ``FeatherDataSet`` pointing to a concrete - filepath. - - Args: - filepath: Filepath in POSIX format to a feather file prefixed with a protocol like - `s3://`. If prefix is not provided, `file` protocol (local filesystem) will be used. - The prefix should be any protocol supported by ``fsspec``. - Note: `http(s)` doesn't support versioning. - load_args: Pandas options for loading feather files. - Here you can find all available arguments: - https://pandas.pydata.org/pandas-docs/stable/reference/api/pandas.read_feather.html - All defaults are preserved. - save_args: Pandas options for saving feather files. - Here you can find all available arguments: - https://pandas.pydata.org/pandas-docs/stable/reference/api/pandas.DataFrame.to_feather.html - All defaults are preserved. - version: If specified, should be an instance of - ``kedro.io.core.Version``. If its ``load`` attribute is - None, the latest version will be loaded. If its ``save`` - attribute is None, save version will be autogenerated. - credentials: Credentials required to get access to the underlying filesystem. - E.g. for ``GCSFileSystem`` it should look like `{"token": None}`. - fs_args: Extra arguments to pass into underlying filesystem class constructor - (e.g. `{"project": "my-project"}` for ``GCSFileSystem``). 
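Like the CSV dataset earlier in this diff, the `_save` method that follows serialises the DataFrame into an in-memory buffer and writes the bytes through ``fsspec`` instead of handing pandas the remote path directly. A minimal sketch of that pattern on the local filesystem (assumes ``pyarrow`` is installed for Feather support; the path is illustrative)::

    from io import BytesIO

    import fsspec
    import pandas as pd

    data = pd.DataFrame({"col1": [1, 2], "col2": [4, 5]})

    buf = BytesIO()
    data.to_feather(buf)                 # pandas writes into the buffer, not to disk

    fs = fsspec.filesystem("file", auto_mkdir=True)
    with fs.open("/tmp/test.feather", mode="wb") as f:
        f.write(buf.getvalue())

    assert data.equals(pd.read_feather("/tmp/test.feather"))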
- """ - _fs_args = deepcopy(fs_args) or {} - _credentials = deepcopy(credentials) or {} - - protocol, path = get_protocol_and_path(filepath, version) - if protocol == "file": - _fs_args.setdefault("auto_mkdir", True) - - self._protocol = protocol - self._storage_options = {**_credentials, **_fs_args} - self._fs = fsspec.filesystem(self._protocol, **self._storage_options) - - super().__init__( - filepath=PurePosixPath(path), - version=version, - exists_function=self._fs.exists, - glob_function=self._fs.glob, - ) - - # Handle default load argument - self._load_args = deepcopy(self.DEFAULT_LOAD_ARGS) - if load_args is not None: - self._load_args.update(load_args) - self._save_args = deepcopy(self.DEFAULT_SAVE_ARGS) - if save_args is not None: - self._save_args.update(save_args) - - if "storage_options" in self._save_args or "storage_options" in self._load_args: - logger.warning( - "Dropping 'storage_options' for %s, " - "please specify them under 'fs_args' or 'credentials'.", - self._filepath, - ) - self._save_args.pop("storage_options", None) - self._load_args.pop("storage_options", None) - - def _describe(self) -> Dict[str, Any]: - return { - "filepath": self._filepath, - "protocol": self._protocol, - "load_args": self._load_args, - "version": self._version, - } - - def _load(self) -> pd.DataFrame: - load_path = str(self._get_load_path()) - if self._protocol == "file": - # file:// protocol seems to misbehave on Windows - # (), - # so we don't join that back to the filepath; - # storage_options also don't work with local paths - return pd.read_feather(load_path, **self._load_args) - - load_path = f"{self._protocol}{PROTOCOL_DELIMITER}{load_path}" - return pd.read_feather( - load_path, storage_options=self._storage_options, **self._load_args - ) - - def _save(self, data: pd.DataFrame) -> None: - save_path = get_filepath_str(self._get_save_path(), self._protocol) - - buf = BytesIO() - data.to_feather(buf, **self._save_args) - - with self._fs.open(save_path, mode="wb") as fs_file: - fs_file.write(buf.getvalue()) - - self._invalidate_cache() - - def _exists(self) -> bool: - load_path = get_filepath_str(self._get_load_path(), self._protocol) - return self._fs.exists(load_path) - - def _release(self) -> None: - super()._release() - self._invalidate_cache() - - def _invalidate_cache(self) -> None: - """Invalidate underlying filesystem caches.""" - filepath = get_filepath_str(self._filepath, self._protocol) - self._fs.invalidate_cache(filepath) diff --git a/kedro/extras/datasets/pandas/gbq_dataset.py b/kedro/extras/datasets/pandas/gbq_dataset.py deleted file mode 100644 index 5a7c460c7c..0000000000 --- a/kedro/extras/datasets/pandas/gbq_dataset.py +++ /dev/null @@ -1,313 +0,0 @@ -"""``GBQTableDataSet`` loads and saves data from/to Google BigQuery. It uses pandas-gbq -to read and write from/to BigQuery table. -""" - -import copy -from pathlib import PurePosixPath -from typing import Any, Dict, NoReturn, Union - -import fsspec -import pandas as pd -from google.cloud import bigquery -from google.cloud.exceptions import NotFound -from google.oauth2.credentials import Credentials - -from kedro.io.core import ( - AbstractDataset, - DatasetError, - get_filepath_str, - get_protocol_and_path, - validate_on_forbidden_chars, -) - -# NOTE: kedro.extras.datasets will be removed in Kedro 0.19.0. 
-# Any contribution to datasets should be made in kedro-datasets -# in kedro-plugins (https://github.com/kedro-org/kedro-plugins) - - -class GBQTableDataSet(AbstractDataset[None, pd.DataFrame]): - """``GBQTableDataSet`` loads and saves data from/to Google BigQuery. - It uses pandas-gbq to read and write from/to BigQuery table. - - Example usage for the - `YAML API `_: - - .. code-block:: yaml - - vehicles: - type: pandas.GBQTableDataSet - dataset: big_query_dataset - table_name: big_query_table - project: my-project - credentials: gbq-creds - load_args: - reauth: True - save_args: - chunk_size: 100 - - Example usage for the - `Python API `_: - :: - - >>> from kedro.extras.datasets.pandas import GBQTableDataSet - >>> import pandas as pd - >>> - >>> data = pd.DataFrame({'col1': [1, 2], 'col2': [4, 5], - >>> 'col3': [5, 6]}) - >>> - >>> data_set = GBQTableDataSet('dataset', - >>> 'table_name', - >>> project='my-project') - >>> data_set.save(data) - >>> reloaded = data_set.load() - >>> - >>> assert data.equals(reloaded) - - """ - - DEFAULT_LOAD_ARGS = {} # type: Dict[str, Any] - DEFAULT_SAVE_ARGS = {"progress_bar": False} # type: Dict[str, Any] - - def __init__( # noqa: too-many-arguments - self, - dataset: str, - table_name: str, - project: str = None, - credentials: Union[Dict[str, Any], Credentials] = None, - load_args: Dict[str, Any] = None, - save_args: Dict[str, Any] = None, - ) -> None: - """Creates a new instance of ``GBQTableDataSet``. - - Args: - dataset: Google BigQuery dataset. - table_name: Google BigQuery table name. - project: Google BigQuery Account project ID. - Optional when available from the environment. - https://cloud.google.com/resource-manager/docs/creating-managing-projects - credentials: Credentials for accessing Google APIs. - Either ``google.auth.credentials.Credentials`` object or dictionary with - parameters required to instantiate ``google.oauth2.credentials.Credentials``. - Here you can find all the arguments: - https://google-auth.readthedocs.io/en/latest/reference/google.oauth2.credentials.html - load_args: Pandas options for loading BigQuery table into DataFrame. - Here you can find all available arguments: - https://pandas.pydata.org/pandas-docs/stable/reference/api/pandas.read_gbq.html - All defaults are preserved. - save_args: Pandas options for saving DataFrame to BigQuery table. - Here you can find all available arguments: - https://pandas.pydata.org/pandas-docs/stable/reference/api/pandas.DataFrame.to_gbq.html - All defaults are preserved, but "progress_bar", which is set to False. - - Raises: - DatasetError: When ``load_args['location']`` and ``save_args['location']`` - are different. 
- """ - # Handle default load and save arguments - self._load_args = copy.deepcopy(self.DEFAULT_LOAD_ARGS) - if load_args is not None: - self._load_args.update(load_args) - self._save_args = copy.deepcopy(self.DEFAULT_SAVE_ARGS) - if save_args is not None: - self._save_args.update(save_args) - - self._validate_location() - validate_on_forbidden_chars(dataset=dataset, table_name=table_name) - - if isinstance(credentials, dict): - credentials = Credentials(**credentials) - - self._dataset = dataset - self._table_name = table_name - self._project_id = project - self._credentials = credentials - self._client = bigquery.Client( - project=self._project_id, - credentials=self._credentials, - location=self._save_args.get("location"), - ) - - def _describe(self) -> Dict[str, Any]: - return { - "dataset": self._dataset, - "table_name": self._table_name, - "load_args": self._load_args, - "save_args": self._save_args, - } - - def _load(self) -> pd.DataFrame: - sql = f"select * from {self._dataset}.{self._table_name}" # nosec - self._load_args.setdefault("query", sql) - return pd.read_gbq( - project_id=self._project_id, - credentials=self._credentials, - **self._load_args, - ) - - def _save(self, data: pd.DataFrame) -> None: - data.to_gbq( - f"{self._dataset}.{self._table_name}", - project_id=self._project_id, - credentials=self._credentials, - **self._save_args, - ) - - def _exists(self) -> bool: - table_ref = self._client.dataset(self._dataset).table(self._table_name) - try: - self._client.get_table(table_ref) - return True - except NotFound: - return False - - def _validate_location(self): - save_location = self._save_args.get("location") - load_location = self._load_args.get("location") - - if save_location != load_location: - raise DatasetError( - """"load_args['location']" is different from "save_args['location']". """ - "The 'location' defines where BigQuery data is stored, therefore has " - "to be the same for save and load args. " - "Details: https://cloud.google.com/bigquery/docs/locations" - ) - - -class GBQQueryDataSet(AbstractDataset[None, pd.DataFrame]): - """``GBQQueryDataSet`` loads data from a provided SQL query from Google - BigQuery. It uses ``pandas.read_gbq`` which itself uses ``pandas-gbq`` - internally to read from BigQuery table. Therefore it supports all allowed - pandas options on ``read_gbq``. - - Example adding a catalog entry with the ``YAML API``: - - .. code-block:: yaml - - >>> vehicles: - >>> type: pandas.GBQQueryDataSet - >>> sql: "select shuttle, shuttle_id from spaceflights.shuttles;" - >>> project: my-project - >>> credentials: gbq-creds - >>> load_args: - >>> reauth: True - - - Example using Python API: - :: - - >>> from kedro.extras.datasets.pandas import GBQQueryDataSet - >>> - >>> sql = "SELECT * FROM dataset_1.table_a" - >>> - >>> data_set = GBQQueryDataSet(sql, project='my-project') - >>> - >>> sql_data = data_set.load() - >>> - """ - - DEFAULT_LOAD_ARGS = {} # type: Dict[str, Any] - - def __init__( # noqa: too-many-arguments - self, - sql: str = None, - project: str = None, - credentials: Union[Dict[str, Any], Credentials] = None, - load_args: Dict[str, Any] = None, - fs_args: Dict[str, Any] = None, - filepath: str = None, - ) -> None: - """Creates a new instance of ``GBQQueryDataSet``. - - Args: - sql: The sql query statement. - project: Google BigQuery Account project ID. - Optional when available from the environment. - https://cloud.google.com/resource-manager/docs/creating-managing-projects - credentials: Credentials for accessing Google APIs. 
- Either ``google.auth.credentials.Credentials`` object or dictionary with - parameters required to instantiate ``google.oauth2.credentials.Credentials``. - Here you can find all the arguments: - https://google-auth.readthedocs.io/en/latest/reference/google.oauth2.credentials.html - load_args: Pandas options for loading BigQuery table into DataFrame. - Here you can find all available arguments: - https://pandas.pydata.org/pandas-docs/stable/reference/api/pandas.read_gbq.html - All defaults are preserved. - fs_args: Extra arguments to pass into underlying filesystem class constructor - (e.g. `{"project": "my-project"}` for ``GCSFileSystem``) used for reading the - SQL query from filepath. - filepath: A path to a file with a sql query statement. - - Raises: - DatasetError: When ``sql`` and ``filepath`` parameters are either both empty - or both provided, as well as when the `save()` method is invoked. - """ - if sql and filepath: - raise DatasetError( - "'sql' and 'filepath' arguments cannot both be provided." - "Please only provide one." - ) - - if not (sql or filepath): - raise DatasetError( - "'sql' and 'filepath' arguments cannot both be empty." - "Please provide a sql query or path to a sql query file." - ) - - # Handle default load arguments - self._load_args = copy.deepcopy(self.DEFAULT_LOAD_ARGS) - if load_args is not None: - self._load_args.update(load_args) - - self._project_id = project - - if isinstance(credentials, dict): - credentials = Credentials(**credentials) - - self._credentials = credentials - self._client = bigquery.Client( - project=self._project_id, - credentials=self._credentials, - location=self._load_args.get("location"), - ) - - # load sql query from arg or from file - if sql: - self._load_args["query"] = sql - self._filepath = None - else: - # filesystem for loading sql file - _fs_args = copy.deepcopy(fs_args) or {} - _fs_credentials = _fs_args.pop("credentials", {}) - protocol, path = get_protocol_and_path(str(filepath)) - - self._protocol = protocol - self._fs = fsspec.filesystem(self._protocol, **_fs_credentials, **_fs_args) - self._filepath = path - - def _describe(self) -> Dict[str, Any]: - load_args = copy.deepcopy(self._load_args) - desc = {} - desc["sql"] = str(load_args.pop("query", None)) - desc["filepath"] = str(self._filepath) - desc["load_args"] = str(load_args) - - return desc - - def _load(self) -> pd.DataFrame: - load_args = copy.deepcopy(self._load_args) - - if self._filepath: - load_path = get_filepath_str(PurePosixPath(self._filepath), self._protocol) - with self._fs.open(load_path, mode="r") as fs_file: - load_args["query"] = fs_file.read() - - return pd.read_gbq( - project_id=self._project_id, - credentials=self._credentials, - **load_args, - ) - - def _save(self, data: None) -> NoReturn: - raise DatasetError("'save' is not supported on GBQQueryDataSet") diff --git a/kedro/extras/datasets/pandas/generic_dataset.py b/kedro/extras/datasets/pandas/generic_dataset.py deleted file mode 100644 index 9d173d6524..0000000000 --- a/kedro/extras/datasets/pandas/generic_dataset.py +++ /dev/null @@ -1,247 +0,0 @@ -"""``GenericDataSet`` loads/saves data from/to a data file using an underlying -filesystem (e.g.: local, S3, GCS). It uses pandas to handle the -type of read/write target. 
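For reference while ``GBQQueryDataSet`` is being removed: its ``sql`` and ``filepath`` arguments are mutually exclusive, and exactly one must be given. A sketch of the file-based variant, with a hypothetical ``.sql`` path::

    from kedro.extras.datasets.pandas import GBQQueryDataSet

    # the query text is read from the file via fsspec, then executed with pandas.read_gbq
    shuttles = GBQQueryDataSet(
        filepath="queries/shuttles.sql",  # hypothetical path to a SQL query file
        project="my-project",
    )
    df = shuttles.load()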
-""" -from copy import deepcopy -from pathlib import PurePosixPath -from typing import Any, Dict - -import fsspec -import pandas as pd - -from kedro.io.core import ( - AbstractVersionedDataset, - DatasetError, - Version, - get_filepath_str, - get_protocol_and_path, -) - -# NOTE: kedro.extras.datasets will be removed in Kedro 0.19.0. -# Any contribution to datasets should be made in kedro-datasets -# in kedro-plugins (https://github.com/kedro-org/kedro-plugins) - - -NON_FILE_SYSTEM_TARGETS = [ - "clipboard", - "numpy", - "sql", - "period", - "records", - "timestamp", - "xarray", - "sql_table", -] - - -class GenericDataSet(AbstractVersionedDataset[pd.DataFrame, pd.DataFrame]): - """`pandas.GenericDataSet` loads/saves data from/to a data file using an underlying - filesystem (e.g.: local, S3, GCS). It uses pandas to dynamically select the - appropriate type of read/write target on a best effort basis. - - Example usage for the - `YAML API `_: - - .. code-block:: yaml - - cars: - type: pandas.GenericDataSet - file_format: csv - filepath: s3://data/01_raw/company/cars.csv - load_args: - sep: "," - na_values: ["#NA", NA] - save_args: - index: False - date_format: "%Y-%m-%d" - - This second example is able to load a SAS7BDAT file via the ``pd.read_sas`` method. - Trying to save this dataset will raise a ``DatasetError`` since pandas does not provide an - equivalent ``pd.DataFrame.to_sas`` write method. - - .. code-block:: yaml - - flights: - type: pandas.GenericDataSet - file_format: sas - filepath: data/01_raw/airplanes.sas7bdat - load_args: - format: sas7bdat - - Example usage for the - `Python API `_: - :: - - >>> from kedro.extras.datasets.pandas import GenericDataSet - >>> import pandas as pd - >>> - >>> data = pd.DataFrame({'col1': [1, 2], 'col2': [4, 5], - >>> 'col3': [5, 6]}) - >>> - >>> data_set = GenericDataSet(filepath="test.csv", file_format='csv') - >>> data_set.save(data) - >>> reloaded = data_set.load() - >>> assert data.equals(reloaded) - - """ - - DEFAULT_LOAD_ARGS = {} # type: Dict[str, Any] - DEFAULT_SAVE_ARGS = {} # type: Dict[str, Any] - - def __init__( # noqa: too-many-arguments - self, - filepath: str, - file_format: str, - load_args: Dict[str, Any] = None, - save_args: Dict[str, Any] = None, - version: Version = None, - credentials: Dict[str, Any] = None, - fs_args: Dict[str, Any] = None, - ): - """Creates a new instance of ``GenericDataSet`` pointing to a concrete data file - on a specific filesystem. The appropriate pandas load/save methods are - dynamically identified by string matching on a best effort basis. - - Args: - filepath: Filepath in POSIX format to a file prefixed with a protocol like `s3://`. - If prefix is not provided, `file` protocol (local filesystem) will be used. - The prefix should be any protocol supported by ``fsspec``. - Key assumption: The first argument of either load/save method points to a - filepath/buffer/io type location. There are some read/write targets such - as 'clipboard' or 'records' that will fail since they do not take a - filepath like argument. - file_format: String which is used to match the appropriate load/save method on a best - effort basis. For example if 'csv' is passed in the `pandas.read_csv` and - `pandas.DataFrame.to_csv` will be identified. An error will be raised unless - at least one matching `read_{file_format}` or `to_{file_format}` method is - identified. - load_args: Pandas options for loading files. 
- Here you can find all available arguments: - https://pandas.pydata.org/pandas-docs/stable/reference/io.html - All defaults are preserved. - save_args: Pandas options for saving files. - Here you can find all available arguments: - https://pandas.pydata.org/pandas-docs/stable/reference/io.html - All defaults are preserved, but "index", which is set to False. - version: If specified, should be an instance of - ``kedro.io.core.Version``. If its ``load`` attribute is - None, the latest version will be loaded. If its ``save`` - attribute is None, save version will be autogenerated. - credentials: Credentials required to get access to the underlying filesystem. - E.g. for ``GCSFileSystem`` it should look like `{"token": None}`. - fs_args: Extra arguments to pass into underlying filesystem class constructor - (e.g. `{"project": "my-project"}` for ``GCSFileSystem``), as well as - to pass to the filesystem's `open` method through nested keys - `open_args_load` and `open_args_save`. - Here you can find all available arguments for `open`: - https://filesystem-spec.readthedocs.io/en/latest/api.html#fsspec.spec.AbstractFileSystem.open - All defaults are preserved, except `mode`, which is set to `r` when loading - and to `w` when saving. - - Raises: - DatasetError: Will be raised if at least less than one appropriate - read or write methods are identified. - """ - - self._file_format = file_format.lower() - - _fs_args = deepcopy(fs_args) or {} - _fs_open_args_load = _fs_args.pop("open_args_load", {}) - _fs_open_args_save = _fs_args.pop("open_args_save", {}) - _credentials = deepcopy(credentials) or {} - - protocol, path = get_protocol_and_path(filepath) - if protocol == "file": - _fs_args.setdefault("auto_mkdir", True) - - self._protocol = protocol - self._fs = fsspec.filesystem(self._protocol, **_credentials, **_fs_args) - - super().__init__( - filepath=PurePosixPath(path), - version=version, - exists_function=self._fs.exists, - glob_function=self._fs.glob, - ) - - self._load_args = deepcopy(self.DEFAULT_LOAD_ARGS) - if load_args is not None: - self._load_args.update(load_args) - self._save_args = deepcopy(self.DEFAULT_SAVE_ARGS) - if save_args is not None: - self._save_args.update(save_args) - - _fs_open_args_save.setdefault("mode", "w") - self._fs_open_args_load = _fs_open_args_load - self._fs_open_args_save = _fs_open_args_save - - def _ensure_file_system_target(self) -> None: - # Fail fast if provided a known non-filesystem target - if self._file_format in NON_FILE_SYSTEM_TARGETS: - raise DatasetError( - f"Cannot create a dataset of file_format '{self._file_format}' as it " - f"does not support a filepath target/source." 
- ) - - def _load(self) -> pd.DataFrame: - - self._ensure_file_system_target() - - load_path = get_filepath_str(self._get_load_path(), self._protocol) - load_method = getattr(pd, f"read_{self._file_format}", None) - if load_method: - with self._fs.open(load_path, **self._fs_open_args_load) as fs_file: - return load_method(fs_file, **self._load_args) - raise DatasetError( - f"Unable to retrieve 'pandas.read_{self._file_format}' method, please ensure that your " - "'file_format' parameter has been defined correctly as per the Pandas API " - "https://pandas.pydata.org/docs/reference/io.html" - ) - - def _save(self, data: pd.DataFrame) -> None: - - self._ensure_file_system_target() - - save_path = get_filepath_str(self._get_save_path(), self._protocol) - save_method = getattr(data, f"to_{self._file_format}", None) - if save_method: - with self._fs.open(save_path, **self._fs_open_args_save) as fs_file: - # KEY ASSUMPTION - first argument is path/buffer/io - save_method(fs_file, **self._save_args) - self._invalidate_cache() - else: - raise DatasetError( - f"Unable to retrieve 'pandas.DataFrame.to_{self._file_format}' method, please " - "ensure that your 'file_format' parameter has been defined correctly as " - "per the Pandas API " - "https://pandas.pydata.org/docs/reference/api/pandas.DataFrame.html" - ) - - def _exists(self) -> bool: - try: - load_path = get_filepath_str(self._get_load_path(), self._protocol) - except DatasetError: - return False - - return self._fs.exists(load_path) - - def _describe(self) -> Dict[str, Any]: - return { - "file_format": self._file_format, - "filepath": self._filepath, - "protocol": self._protocol, - "load_args": self._load_args, - "save_args": self._save_args, - "version": self._version, - } - - def _release(self) -> None: - super()._release() - self._invalidate_cache() - - def _invalidate_cache(self) -> None: - """Invalidate underlying filesystem caches.""" - filepath = get_filepath_str(self._filepath, self._protocol) - self._fs.invalidate_cache(filepath) diff --git a/kedro/extras/datasets/pandas/hdf_dataset.py b/kedro/extras/datasets/pandas/hdf_dataset.py deleted file mode 100644 index aa02434776..0000000000 --- a/kedro/extras/datasets/pandas/hdf_dataset.py +++ /dev/null @@ -1,206 +0,0 @@ -"""``HDFDataSet`` loads/saves data from/to a hdf file using an underlying -filesystem (e.g.: local, S3, GCS). It uses pandas.HDFStore to handle the hdf file. -""" -from copy import deepcopy -from pathlib import PurePosixPath -from threading import Lock -from typing import Any, Dict - -import fsspec -import pandas as pd - -from kedro.io.core import ( - AbstractVersionedDataset, - DatasetError, - Version, - get_filepath_str, - get_protocol_and_path, -) - -HDFSTORE_DRIVER = "H5FD_CORE" - -# NOTE: kedro.extras.datasets will be removed in Kedro 0.19.0. -# Any contribution to datasets should be made in kedro-datasets -# in kedro-plugins (https://github.com/kedro-org/kedro-plugins) - - -class HDFDataSet(AbstractVersionedDataset[pd.DataFrame, pd.DataFrame]): - """``HDFDataSet`` loads/saves data from/to a hdf file using an underlying - filesystem (e.g. local, S3, GCS). It uses pandas.HDFStore to handle the hdf file. - - Example usage for the - `YAML API `_: - - - .. 
code-block:: yaml - - hdf_dataset: - type: pandas.HDFDataSet - filepath: s3://my_bucket/raw/sensor_reading.h5 - credentials: aws_s3_creds - key: data - - Example usage for the - `Python API `_: - :: - - >>> from kedro.extras.datasets.pandas import HDFDataSet - >>> import pandas as pd - >>> - >>> data = pd.DataFrame({'col1': [1, 2], 'col2': [4, 5], - >>> 'col3': [5, 6]}) - >>> - >>> data_set = HDFDataSet(filepath="test.h5", key='data') - >>> data_set.save(data) - >>> reloaded = data_set.load() - >>> assert data.equals(reloaded) - - """ - - # _lock is a class attribute that will be shared across all the instances. - # It is used to make dataset safe for threads. - _lock = Lock() - DEFAULT_LOAD_ARGS = {} # type: Dict[str, Any] - DEFAULT_SAVE_ARGS = {} # type: Dict[str, Any] - - def __init__( # noqa: too-many-arguments - self, - filepath: str, - key: str, - load_args: Dict[str, Any] = None, - save_args: Dict[str, Any] = None, - version: Version = None, - credentials: Dict[str, Any] = None, - fs_args: Dict[str, Any] = None, - ) -> None: - """Creates a new instance of ``HDFDataSet`` pointing to a concrete hdf file - on a specific filesystem. - - Args: - filepath: Filepath in POSIX format to a hdf file prefixed with a protocol like `s3://`. - If prefix is not provided, `file` protocol (local filesystem) will be used. - The prefix should be any protocol supported by ``fsspec``. - Note: `http(s)` doesn't support versioning. - key: Identifier to the group in the HDF store. - load_args: PyTables options for loading hdf files. - You can find all available arguments at: - https://www.pytables.org/usersguide/libref/top_level.html#tables.open_file - All defaults are preserved. - save_args: PyTables options for saving hdf files. - You can find all available arguments at: - https://www.pytables.org/usersguide/libref/top_level.html#tables.open_file - All defaults are preserved. - version: If specified, should be an instance of - ``kedro.io.core.Version``. If its ``load`` attribute is - None, the latest version will be loaded. If its ``save`` - attribute is None, save version will be autogenerated. - credentials: Credentials required to get access to the underlying filesystem. - E.g. for ``GCSFileSystem`` it should look like `{"token": None}`. - fs_args: Extra arguments to pass into underlying filesystem class constructor - (e.g. `{"project": "my-project"}` for ``GCSFileSystem``), as well as - to pass to the filesystem's `open` method through nested keys - `open_args_load` and `open_args_save`. - Here you can find all available arguments for `open`: - https://filesystem-spec.readthedocs.io/en/latest/api.html#fsspec.spec.AbstractFileSystem.open - All defaults are preserved, except `mode`, which is set `wb` when saving. 
- """ - _fs_args = deepcopy(fs_args) or {} - _fs_open_args_load = _fs_args.pop("open_args_load", {}) - _fs_open_args_save = _fs_args.pop("open_args_save", {}) - _credentials = deepcopy(credentials) or {} - - protocol, path = get_protocol_and_path(filepath, version) - if protocol == "file": - _fs_args.setdefault("auto_mkdir", True) - - self._protocol = protocol - self._fs = fsspec.filesystem(self._protocol, **_credentials, **_fs_args) - - super().__init__( - filepath=PurePosixPath(path), - version=version, - exists_function=self._fs.exists, - glob_function=self._fs.glob, - ) - - self._key = key - - # Handle default load and save arguments - self._load_args = deepcopy(self.DEFAULT_LOAD_ARGS) - if load_args is not None: - self._load_args.update(load_args) - self._save_args = deepcopy(self.DEFAULT_SAVE_ARGS) - if save_args is not None: - self._save_args.update(save_args) - - _fs_open_args_save.setdefault("mode", "wb") - self._fs_open_args_load = _fs_open_args_load - self._fs_open_args_save = _fs_open_args_save - - def _describe(self) -> Dict[str, Any]: - return { - "filepath": self._filepath, - "key": self._key, - "protocol": self._protocol, - "load_args": self._load_args, - "save_args": self._save_args, - "version": self._version, - } - - def _load(self) -> pd.DataFrame: - load_path = get_filepath_str(self._get_load_path(), self._protocol) - - with self._fs.open(load_path, **self._fs_open_args_load) as fs_file: - binary_data = fs_file.read() - - with HDFDataSet._lock: - # Set driver_core_backing_store to False to disable saving - # contents of the in-memory h5file to disk - with pd.HDFStore( - "in-memory-load-file", - mode="r", - driver=HDFSTORE_DRIVER, - driver_core_backing_store=0, - driver_core_image=binary_data, - **self._load_args, - ) as store: - return store[self._key] - - def _save(self, data: pd.DataFrame) -> None: - save_path = get_filepath_str(self._get_save_path(), self._protocol) - - with HDFDataSet._lock: - with pd.HDFStore( - "in-memory-save-file", - mode="w", - driver=HDFSTORE_DRIVER, - driver_core_backing_store=0, - **self._save_args, - ) as store: - store.put(self._key, data, format="table") - # pylint: disable=protected-access - binary_data = store._handle.get_file_image() - - with self._fs.open(save_path, **self._fs_open_args_save) as fs_file: - fs_file.write(binary_data) - - self._invalidate_cache() - - def _exists(self) -> bool: - try: - load_path = get_filepath_str(self._get_load_path(), self._protocol) - except DatasetError: - return False - - return self._fs.exists(load_path) - - def _release(self) -> None: - super()._release() - self._invalidate_cache() - - def _invalidate_cache(self) -> None: - """Invalidate underlying filesystem caches.""" - filepath = get_filepath_str(self._filepath, self._protocol) - self._fs.invalidate_cache(filepath) diff --git a/kedro/extras/datasets/pandas/json_dataset.py b/kedro/extras/datasets/pandas/json_dataset.py deleted file mode 100644 index c2cf971bb9..0000000000 --- a/kedro/extras/datasets/pandas/json_dataset.py +++ /dev/null @@ -1,188 +0,0 @@ -"""``JSONDataSet`` loads/saves data from/to a JSON file using an underlying -filesystem (e.g.: local, S3, GCS). It uses pandas to handle the JSON file. 
-""" -import logging -from copy import deepcopy -from io import BytesIO -from pathlib import PurePosixPath -from typing import Any, Dict - -import fsspec -import pandas as pd - -from kedro.io.core import ( - PROTOCOL_DELIMITER, - AbstractVersionedDataset, - DatasetError, - Version, - get_filepath_str, - get_protocol_and_path, -) - -logger = logging.getLogger(__name__) - -# NOTE: kedro.extras.datasets will be removed in Kedro 0.19.0. -# Any contribution to datasets should be made in kedro-datasets -# in kedro-plugins (https://github.com/kedro-org/kedro-plugins) - - -class JSONDataSet(AbstractVersionedDataset[pd.DataFrame, pd.DataFrame]): - """``JSONDataSet`` loads/saves data from/to a JSON file using an underlying - filesystem (e.g.: local, S3, GCS). It uses pandas to handle the json file. - - Example usage for the - `YAML API `_: - - - .. code-block:: yaml - - clickstream_dataset: - type: pandas.JSONDataSet - filepath: abfs://landing_area/primary/click_stream.json - credentials: abfs_creds - - json_dataset: - type: pandas.JSONDataSet - filepath: data/01_raw/Video_Games.json - load_args: - lines: True - - Example usage for the - `Python API `_: - :: - - >>> from kedro.extras.datasets.pandas import JSONDataSet - >>> import pandas as pd - >>> - >>> data = pd.DataFrame({'col1': [1, 2], 'col2': [4, 5], - >>> 'col3': [5, 6]}) - >>> - >>> data_set = JSONDataSet(filepath="test.json") - >>> data_set.save(data) - >>> reloaded = data_set.load() - >>> assert data.equals(reloaded) - - """ - - DEFAULT_LOAD_ARGS = {} # type: Dict[str, Any] - DEFAULT_SAVE_ARGS = {} # type: Dict[str, Any] - - def __init__( # noqa: too-many-arguments - self, - filepath: str, - load_args: Dict[str, Any] = None, - save_args: Dict[str, Any] = None, - version: Version = None, - credentials: Dict[str, Any] = None, - fs_args: Dict[str, Any] = None, - ) -> None: - """Creates a new instance of ``JSONDataSet`` pointing to a concrete JSON file - on a specific filesystem. - - Args: - filepath: Filepath in POSIX format to a JSON file prefixed with a protocol like `s3://`. - If prefix is not provided `file` protocol (local filesystem) will be used. - The prefix should be any protocol supported by ``fsspec``. - Note: `http(s)` doesn't support versioning. - load_args: Pandas options for loading JSON files. - Here you can find all available arguments: - https://pandas.pydata.org/pandas-docs/stable/generated/pandas.read_json.html - All defaults are preserved. - save_args: Pandas options for saving JSON files. - Here you can find all available arguments: - https://pandas.pydata.org/pandas-docs/stable/generated/pandas.DataFrame.to_json.html - All defaults are preserved, but "index", which is set to False. - version: If specified, should be an instance of - ``kedro.io.core.Version``. If its ``load`` attribute is - None, the latest version will be loaded. If its ``save`` - attribute is None, save version will be autogenerated. - credentials: Credentials required to get access to the underlying filesystem. - E.g. for ``GCSFileSystem`` it should look like `{'token': None}`. - fs_args: Extra arguments to pass into underlying filesystem class constructor - (e.g. `{"project": "my-project"}` for ``GCSFileSystem``). 
- """ - _fs_args = deepcopy(fs_args) or {} - _credentials = deepcopy(credentials) or {} - protocol, path = get_protocol_and_path(filepath, version) - if protocol == "file": - _fs_args.setdefault("auto_mkdir", True) - - self._protocol = protocol - self._storage_options = {**_credentials, **_fs_args} - self._fs = fsspec.filesystem(self._protocol, **self._storage_options) - - super().__init__( - filepath=PurePosixPath(path), - version=version, - exists_function=self._fs.exists, - glob_function=self._fs.glob, - ) - - # Handle default load and save arguments - self._load_args = deepcopy(self.DEFAULT_LOAD_ARGS) - if load_args is not None: - self._load_args.update(load_args) - self._save_args = deepcopy(self.DEFAULT_SAVE_ARGS) - if save_args is not None: - self._save_args.update(save_args) - - if "storage_options" in self._save_args or "storage_options" in self._load_args: - logger.warning( - "Dropping 'storage_options' for %s, " - "please specify them under 'fs_args' or 'credentials'.", - self._filepath, - ) - self._save_args.pop("storage_options", None) - self._load_args.pop("storage_options", None) - - def _describe(self) -> Dict[str, Any]: - return { - "filepath": self._filepath, - "protocol": self._load_args, - "save_args": self._save_args, - "version": self._version, - } - - def _load(self) -> pd.DataFrame: - load_path = str(self._get_load_path()) - if self._protocol == "file": - # file:// protocol seems to misbehave on Windows - # (), - # so we don't join that back to the filepath; - # storage_options also don't work with local paths - return pd.read_json(load_path, **self._load_args) - - load_path = f"{self._protocol}{PROTOCOL_DELIMITER}{load_path}" - return pd.read_json( - load_path, storage_options=self._storage_options, **self._load_args - ) - - def _save(self, data: pd.DataFrame) -> None: - save_path = get_filepath_str(self._get_save_path(), self._protocol) - - buf = BytesIO() - data.to_json(path_or_buf=buf, **self._save_args) - - with self._fs.open(save_path, mode="wb") as fs_file: - fs_file.write(buf.getvalue()) - - self._invalidate_cache() - - def _exists(self) -> bool: - try: - load_path = get_filepath_str(self._get_load_path(), self._protocol) - except DatasetError: - return False - - return self._fs.exists(load_path) - - def _release(self) -> None: - super()._release() - self._invalidate_cache() - - def _invalidate_cache(self) -> None: - """Invalidate underlying filesystem caches.""" - filepath = get_filepath_str(self._filepath, self._protocol) - self._fs.invalidate_cache(filepath) diff --git a/kedro/extras/datasets/pandas/parquet_dataset.py b/kedro/extras/datasets/pandas/parquet_dataset.py deleted file mode 100644 index 43c603f2ae..0000000000 --- a/kedro/extras/datasets/pandas/parquet_dataset.py +++ /dev/null @@ -1,231 +0,0 @@ -"""``ParquetDataSet`` loads/saves data from/to a Parquet file using an underlying -filesystem (e.g.: local, S3, GCS). It uses pandas to handle the Parquet file. -""" -import logging -from copy import deepcopy -from io import BytesIO -from pathlib import Path, PurePosixPath -from typing import Any, Dict - -import fsspec -import pandas as pd -import pyarrow.parquet as pq - -from kedro.io.core import ( - PROTOCOL_DELIMITER, - AbstractVersionedDataset, - DatasetError, - Version, - get_filepath_str, - get_protocol_and_path, -) - -logger = logging.getLogger(__name__) - -# NOTE: kedro.extras.datasets will be removed in Kedro 0.19.0. 
-# Any contribution to datasets should be made in kedro-datasets -# in kedro-plugins (https://github.com/kedro-org/kedro-plugins) - - -class ParquetDataSet(AbstractVersionedDataset[pd.DataFrame, pd.DataFrame]): - """``ParquetDataSet`` loads/saves data from/to a Parquet file using an underlying - filesystem (e.g.: local, S3, GCS). It uses pandas to handle the Parquet file. - - Example usage for the - `YAML API `_: - - - .. code-block:: yaml - - boats: - type: pandas.ParquetDataSet - filepath: data/01_raw/boats.parquet - load_args: - engine: pyarrow - use_nullable_dtypes: True - save_args: - file_scheme: hive - has_nulls: False - engine: pyarrow - - trucks: - type: pandas.ParquetDataSet - filepath: abfs://container/02_intermediate/trucks.parquet - credentials: dev_abs - load_args: - columns: [name, gear, disp, wt] - index: name - save_args: - compression: GZIP - partition_on: [name] - - Example usage for the - `Python API `_: - :: - - >>> from kedro.extras.datasets.pandas import ParquetDataSet - >>> import pandas as pd - >>> - >>> data = pd.DataFrame({'col1': [1, 2], 'col2': [4, 5], - >>> 'col3': [5, 6]}) - >>> - >>> data_set = ParquetDataSet(filepath="test.parquet") - >>> data_set.save(data) - >>> reloaded = data_set.load() - >>> assert data.equals(reloaded) - - """ - - DEFAULT_LOAD_ARGS = {} # type: Dict[str, Any] - DEFAULT_SAVE_ARGS = {} # type: Dict[str, Any] - - def __init__( # noqa: too-many-arguments - self, - filepath: str, - load_args: Dict[str, Any] = None, - save_args: Dict[str, Any] = None, - version: Version = None, - credentials: Dict[str, Any] = None, - fs_args: Dict[str, Any] = None, - ) -> None: - """Creates a new instance of ``ParquetDataSet`` pointing to a concrete Parquet file - on a specific filesystem. - - Args: - filepath: Filepath in POSIX format to a Parquet file prefixed with a protocol like - `s3://`. If prefix is not provided, `file` protocol (local filesystem) will be used. - The prefix should be any protocol supported by ``fsspec``. - It can also be a path to a directory. If the directory is - provided then it can be used for reading partitioned parquet files. - Note: `http(s)` doesn't support versioning. - load_args: Additional options for loading Parquet file(s). - Here you can find all available arguments when reading single file: - https://pandas.pydata.org/pandas-docs/stable/reference/api/pandas.read_parquet.html - Here you can find all available arguments when reading partitioned datasets: - https://arrow.apache.org/docs/python/generated/pyarrow.parquet.ParquetDataset.html#pyarrow.parquet.ParquetDataset.read - All defaults are preserved. - save_args: Additional saving options for saving Parquet file(s). - Here you can find all available arguments: - https://pandas.pydata.org/pandas-docs/stable/reference/api/pandas.DataFrame.to_parquet.html - All defaults are preserved. ``partition_cols`` is not supported. - version: If specified, should be an instance of ``kedro.io.core.Version``. - If its ``load`` attribute is None, the latest version will be loaded. If - its ``save`` attribute is None, save version will be autogenerated. - credentials: Credentials required to get access to the underlying filesystem. - E.g. for ``GCSFileSystem`` it should look like `{"token": None}`. - fs_args: Extra arguments to pass into underlying filesystem class constructor - (e.g. `{"project": "my-project"}` for ``GCSFileSystem``). 
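When ``filepath`` points to a directory of partitioned Parquet files, as noted above, the removed ``_load`` (shown just below) reads it through pyarrow rather than ``pandas.read_parquet``. A minimal local sketch mirroring that branch, with a hypothetical directory::

    import fsspec
    import pyarrow.parquet as pq

    fs = fsspec.filesystem("file")
    dataset = pq.ParquetDataset("data/01_raw/boats_partitioned", filesystem=fs)  # hypothetical directory
    df = dataset.read().to_pandas()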
- """ - _fs_args = deepcopy(fs_args) or {} - _credentials = deepcopy(credentials) or {} - - protocol, path = get_protocol_and_path(filepath, version) - if protocol == "file": - _fs_args.setdefault("auto_mkdir", True) - - self._protocol = protocol - self._storage_options = {**_credentials, **_fs_args} - self._fs = fsspec.filesystem(self._protocol, **self._storage_options) - - super().__init__( - filepath=PurePosixPath(path), - version=version, - exists_function=self._fs.exists, - glob_function=self._fs.glob, - ) - - # Handle default load and save arguments - self._load_args = deepcopy(self.DEFAULT_LOAD_ARGS) - if load_args is not None: - self._load_args.update(load_args) - self._save_args = deepcopy(self.DEFAULT_SAVE_ARGS) - if save_args is not None: - self._save_args.update(save_args) - - if "storage_options" in self._save_args or "storage_options" in self._load_args: - logger.warning( - "Dropping 'storage_options' for %s, " - "please specify them under 'fs_args' or 'credentials'.", - self._filepath, - ) - self._save_args.pop("storage_options", None) - self._load_args.pop("storage_options", None) - - def _describe(self) -> Dict[str, Any]: - return { - "filepath": self._filepath, - "protocol": self._load_args, - "save_args": self._save_args, - "version": self._version, - } - - def _load(self) -> pd.DataFrame: - load_path = get_filepath_str(self._get_load_path(), self._protocol) - - if self._fs.isdir(load_path): - # It doesn't work at least on S3 if root folder was created manually - # https://issues.apache.org/jira/browse/ARROW-7867 - data = ( - pq.ParquetDataset(load_path, filesystem=self._fs) - .read(**self._load_args) - .to_pandas() - ) - else: - data = self._load_from_pandas() - - return data - - def _load_from_pandas(self): - load_path = str(self._get_load_path()) - if self._protocol == "file": - # file:// protocol seems to misbehave on Windows - # (), - # so we don't join that back to the filepath; - # storage_options also don't work with local paths - return pd.read_parquet(load_path, **self._load_args) - - load_path = f"{self._protocol}{PROTOCOL_DELIMITER}{load_path}" - return pd.read_parquet( - load_path, storage_options=self._storage_options, **self._load_args - ) - - def _save(self, data: pd.DataFrame) -> None: - save_path = get_filepath_str(self._get_save_path(), self._protocol) - - if Path(save_path).is_dir(): - raise DatasetError( - f"Saving {self.__class__.__name__} to a directory is not supported." - ) - - if "partition_cols" in self._save_args: - raise DatasetError( - f"{self.__class__.__name__} does not support save argument " - f"'partition_cols'. Please use 'kedro.io.PartitionedDataSet' instead." 
- ) - - bytes_buffer = BytesIO() - data.to_parquet(bytes_buffer, **self._save_args) - - with self._fs.open(save_path, mode="wb") as fs_file: - fs_file.write(bytes_buffer.getvalue()) - - self._invalidate_cache() - - def _exists(self) -> bool: - try: - load_path = get_filepath_str(self._get_load_path(), self._protocol) - except DatasetError: - return False - - return self._fs.exists(load_path) - - def _release(self) -> None: - super()._release() - self._invalidate_cache() - - def _invalidate_cache(self) -> None: - """Invalidate underlying filesystem caches.""" - filepath = get_filepath_str(self._filepath, self._protocol) - self._fs.invalidate_cache(filepath) diff --git a/kedro/extras/datasets/pandas/sql_dataset.py b/kedro/extras/datasets/pandas/sql_dataset.py deleted file mode 100644 index 03b3c43aee..0000000000 --- a/kedro/extras/datasets/pandas/sql_dataset.py +++ /dev/null @@ -1,467 +0,0 @@ -"""``SQLDataSet`` to load and save data to a SQL backend.""" - -import copy -import re -from pathlib import PurePosixPath -from typing import Any, Dict, NoReturn, Optional - -import fsspec -import pandas as pd -from sqlalchemy import create_engine -from sqlalchemy.exc import NoSuchModuleError - -from kedro.io.core import ( - AbstractDataset, - DatasetError, - get_filepath_str, - get_protocol_and_path, -) - -# NOTE: kedro.extras.datasets will be removed in Kedro 0.19.0. -# Any contribution to datasets should be made in kedro-datasets -# in kedro-plugins (https://github.com/kedro-org/kedro-plugins) - - -__all__ = ["SQLTableDataSet", "SQLQueryDataSet"] - -KNOWN_PIP_INSTALL = { - "psycopg2": "psycopg2", - "mysqldb": "mysqlclient", - "cx_Oracle": "cx_Oracle", -} - -DRIVER_ERROR_MESSAGE = """ -A module/driver is missing when connecting to your SQL server. SQLDataSet - supports SQLAlchemy drivers. Please refer to - https://docs.sqlalchemy.org/en/13/core/engines.html#supported-databases - for more information. -\n\n -""" - - -def _find_known_drivers(module_import_error: ImportError) -> Optional[str]: - """Looks up known keywords in a ``ModuleNotFoundError`` so that it can - provide better guideline for the user. - - Args: - module_import_error: Error raised while connecting to a SQL server. - - Returns: - Instructions for installing missing driver. An empty string is - returned in case error is related to an unknown driver. - - """ - - # module errors contain string "No module name 'module_name'" - # we are trying to extract module_name surrounded by quotes here - res = re.findall(r"'(.*?)'", str(module_import_error.args[0]).lower()) - - # in case module import error does not match our expected pattern - # we have no recommendation - if not res: - return None - - missing_module = res[0] - - if KNOWN_PIP_INSTALL.get(missing_module): - return ( - f"You can also try installing missing driver with\n" - f"\npip install {KNOWN_PIP_INSTALL.get(missing_module)}" - ) - - return None - - -def _get_missing_module_error(import_error: ImportError) -> DatasetError: - missing_module_instruction = _find_known_drivers(import_error) - - if missing_module_instruction is None: - return DatasetError( - f"{DRIVER_ERROR_MESSAGE}Loading failed with error:\n\n{str(import_error)}" - ) - - return DatasetError(f"{DRIVER_ERROR_MESSAGE}{missing_module_instruction}") - - -def _get_sql_alchemy_missing_error() -> DatasetError: - return DatasetError( - "The SQL dialect in your connection is not supported by " - "SQLAlchemy. Please refer to " - "https://docs.sqlalchemy.org/en/13/core/engines.html#supported-databases " - "for more information." 
- ) - - -class SQLTableDataSet(AbstractDataset[pd.DataFrame, pd.DataFrame]): - """``SQLTableDataSet`` loads data from a SQL table and saves a pandas - dataframe to a table. It uses ``pandas.DataFrame`` internally, - so it supports all allowed pandas options on ``read_sql_table`` and - ``to_sql`` methods. Since Pandas uses SQLAlchemy behind the scenes, when - instantiating ``SQLTableDataSet`` one needs to pass a compatible connection - string either in ``credentials`` (see the example code snippet below) or in - ``load_args`` and ``save_args``. Connection string formats supported by - SQLAlchemy can be found here: - https://docs.sqlalchemy.org/en/13/core/engines.html#database-urls - - ``SQLTableDataSet`` modifies the save parameters and stores - the data with no index. This is designed to make load and save methods - symmetric. - - Example usage for the - `YAML API `_: - - - .. code-block:: yaml - - shuttles_table_dataset: - type: pandas.SQLTableDataSet - credentials: db_credentials - table_name: shuttles - load_args: - schema: dwschema - save_args: - schema: dwschema - if_exists: replace - - Sample database credentials entry in ``credentials.yml``: - - .. code-block:: yaml - - db_credentials: - con: postgresql://scott:tiger@localhost/test - - Example usage for the - `Python API `_: - :: - - >>> from kedro.extras.datasets.pandas import SQLTableDataSet - >>> import pandas as pd - >>> - >>> data = pd.DataFrame({"col1": [1, 2], "col2": [4, 5], - >>> "col3": [5, 6]}) - >>> table_name = "table_a" - >>> credentials = { - >>> "con": "postgresql://scott:tiger@localhost/test" - >>> } - >>> data_set = SQLTableDataSet(table_name=table_name, - >>> credentials=credentials) - >>> - >>> data_set.save(data) - >>> reloaded = data_set.load() - >>> - >>> assert data.equals(reloaded) - - """ - - DEFAULT_LOAD_ARGS: Dict[str, Any] = {} - DEFAULT_SAVE_ARGS: Dict[str, Any] = {"index": False} - # using Any because of Sphinx but it should be - # sqlalchemy.engine.Engine or sqlalchemy.engine.base.Engine - engines: Dict[str, Any] = {} - - def __init__( - self, - table_name: str, - credentials: Dict[str, Any], - load_args: Dict[str, Any] = None, - save_args: Dict[str, Any] = None, - ) -> None: - """Creates a new ``SQLTableDataSet``. - - Args: - table_name: The table name to load or save data to. It - overwrites name in ``save_args`` and ``table_name`` - parameters in ``load_args``. - credentials: A dictionary with a ``SQLAlchemy`` connection string. - Users are supposed to provide the connection string 'con' - through credentials. It overwrites `con` parameter in - ``load_args`` and ``save_args`` in case it is provided. To find - all supported connection string formats, see here: - https://docs.sqlalchemy.org/en/13/core/engines.html#database-urls - load_args: Provided to underlying pandas ``read_sql_table`` - function along with the connection string. - To find all supported arguments, see here: - https://pandas.pydata.org/pandas-docs/stable/generated/pandas.read_sql_table.html - To find all supported connection string formats, see here: - https://docs.sqlalchemy.org/en/13/core/engines.html#database-urls - save_args: Provided to underlying pandas ``to_sql`` function along - with the connection string. - To find all supported arguments, see here: - https://pandas.pydata.org/pandas-docs/stable/generated/pandas.DataFrame.to_sql.html - To find all supported connection string formats, see here: - https://docs.sqlalchemy.org/en/13/core/engines.html#database-urls - It has ``index=False`` in the default parameters. 
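A sketch of the connection handling described above: the connection string arrives through ``credentials["con"]`` and a single SQLAlchemy engine per connection string is cached on the class and reused across instances. SQLite is used here purely as a stand-in connection string::

    from sqlalchemy import create_engine

    engines = {}  # mirrors the class-level SQLTableDataSet.engines cache

    def get_engine(connection_str: str):
        if connection_str not in engines:
            engines[connection_str] = create_engine(connection_str)
        return engines[connection_str]

    engine = get_engine("sqlite:///kedro_example.db")  # hypothetical credentials["con"]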
- - Raises: - DatasetError: When either ``table_name`` or ``con`` is empty. - """ - - if not table_name: - raise DatasetError("'table_name' argument cannot be empty.") - - if not (credentials and "con" in credentials and credentials["con"]): - raise DatasetError( - "'con' argument cannot be empty. Please " - "provide a SQLAlchemy connection string." - ) - - # Handle default load and save arguments - self._load_args = copy.deepcopy(self.DEFAULT_LOAD_ARGS) - if load_args is not None: - self._load_args.update(load_args) - self._save_args = copy.deepcopy(self.DEFAULT_SAVE_ARGS) - if save_args is not None: - self._save_args.update(save_args) - - self._load_args["table_name"] = table_name - self._save_args["name"] = table_name - - self._connection_str = credentials["con"] - self.create_connection(self._connection_str) - - @classmethod - def create_connection(cls, connection_str: str) -> None: - """Given a connection string, create singleton connection - to be used across all instances of `SQLTableDataSet` that - need to connect to the same source. - """ - if connection_str in cls.engines: - return - - try: - engine = create_engine(connection_str) - except ImportError as import_error: - raise _get_missing_module_error(import_error) from import_error - except NoSuchModuleError as exc: - raise _get_sql_alchemy_missing_error() from exc - - cls.engines[connection_str] = engine - - def _describe(self) -> Dict[str, Any]: - load_args = copy.deepcopy(self._load_args) - save_args = copy.deepcopy(self._save_args) - del load_args["table_name"] - del save_args["name"] - return { - "table_name": self._load_args["table_name"], - "load_args": load_args, - "save_args": save_args, - } - - def _load(self) -> pd.DataFrame: - engine = self.engines[self._connection_str] # type:ignore - return pd.read_sql_table(con=engine, **self._load_args) - - def _save(self, data: pd.DataFrame) -> None: - engine = self.engines[self._connection_str] # type: ignore - data.to_sql(con=engine, **self._save_args) - - def _exists(self) -> bool: - eng = self.engines[self._connection_str] # type: ignore - schema = self._load_args.get("schema", None) - exists = self._load_args["table_name"] in eng.table_names(schema) - return exists - - -class SQLQueryDataSet(AbstractDataset[None, pd.DataFrame]): - """``SQLQueryDataSet`` loads data from a provided SQL query. It - uses ``pandas.DataFrame`` internally, so it supports all allowed - pandas options on ``read_sql_query``. Since Pandas uses SQLAlchemy behind - the scenes, when instantiating ``SQLQueryDataSet`` one needs to pass - a compatible connection string either in ``credentials`` (see the example - code snippet below) or in ``load_args``. Connection string formats supported - by SQLAlchemy can be found here: - https://docs.sqlalchemy.org/en/13/core/engines.html#database-urls - - It does not support save method so it is a read only data set. - To save data to a SQL server use ``SQLTableDataSet``. - - - Example usage for the - `YAML API `_: - - - .. code-block:: yaml - - shuttle_id_dataset: - type: pandas.SQLQueryDataSet - sql: "select shuttle, shuttle_id from spaceflights.shuttles;" - credentials: db_credentials - - Advanced example using the ``stream_results`` and ``chunksize`` options to reduce memory usage: - - .. 
code-block:: yaml - - shuttle_id_dataset: - type: pandas.SQLQueryDataSet - sql: "select shuttle, shuttle_id from spaceflights.shuttles;" - credentials: db_credentials - execution_options: - stream_results: true - load_args: - chunksize: 1000 - - Sample database credentials entry in ``credentials.yml``: - - .. code-block:: yaml - - db_credentials: - con: postgresql://scott:tiger@localhost/test - - Example usage for the - `Python API `_: - :: - - - >>> from kedro.extras.datasets.pandas import SQLQueryDataSet - >>> import pandas as pd - >>> - >>> data = pd.DataFrame({"col1": [1, 2], "col2": [4, 5], - >>> "col3": [5, 6]}) - >>> sql = "SELECT * FROM table_a" - >>> credentials = { - >>> "con": "postgresql://scott:tiger@localhost/test" - >>> } - >>> data_set = SQLQueryDataSet(sql=sql, - >>> credentials=credentials) - >>> - >>> sql_data = data_set.load() - - """ - - # using Any because of Sphinx but it should be - # sqlalchemy.engine.Engine or sqlalchemy.engine.base.Engine - engines: Dict[str, Any] = {} - - def __init__( # noqa: too-many-arguments - self, - sql: str = None, - credentials: Dict[str, Any] = None, - load_args: Dict[str, Any] = None, - fs_args: Dict[str, Any] = None, - filepath: str = None, - execution_options: Optional[Dict[str, Any]] = None, - ) -> None: - """Creates a new ``SQLQueryDataSet``. - - Args: - sql: The sql query statement. - credentials: A dictionary with a ``SQLAlchemy`` connection string. - Users are supposed to provide the connection string 'con' - through credentials. It overwrites `con` parameter in - ``load_args`` and ``save_args`` in case it is provided. To find - all supported connection string formats, see here: - https://docs.sqlalchemy.org/en/13/core/engines.html#database-urls - load_args: Provided to underlying pandas ``read_sql_query`` - function along with the connection string. - To find all supported arguments, see here: - https://pandas.pydata.org/pandas-docs/stable/generated/pandas.read_sql_query.html - To find all supported connection string formats, see here: - https://docs.sqlalchemy.org/en/13/core/engines.html#database-urls - fs_args: Extra arguments to pass into underlying filesystem class constructor - (e.g. `{"project": "my-project"}` for ``GCSFileSystem``), as well as - to pass to the filesystem's `open` method through nested keys - `open_args_load` and `open_args_save`. - Here you can find all available arguments for `open`: - https://filesystem-spec.readthedocs.io/en/latest/api.html#fsspec.spec.AbstractFileSystem.open - All defaults are preserved, except `mode`, which is set to `r` when loading. - filepath: A path to a file with a sql query statement. - execution_options: A dictionary with non-SQL advanced options for the connection to - be applied to the underlying engine. To find all supported execution - options, see here: - https://docs.sqlalchemy.org/en/12/core/connections.html#sqlalchemy.engine.Connection.execution_options - Note that this is not a standard argument supported by pandas API, but could be - useful for handling large datasets. - - Raises: - DatasetError: When either ``sql`` or ``con`` parameters is empty. - """ - if sql and filepath: - raise DatasetError( - "'sql' and 'filepath' arguments cannot both be provided." - "Please only provide one." - ) - - if not (sql or filepath): - raise DatasetError( - "'sql' and 'filepath' arguments cannot both be empty." - "Please provide a sql query or path to a sql query file." 
- ) - - if not (credentials and "con" in credentials and credentials["con"]): - raise DatasetError( - "'con' argument cannot be empty. Please " - "provide a SQLAlchemy connection string." - ) - - default_load_args = {} # type: Dict[str, Any] - - self._load_args = ( - {**default_load_args, **load_args} - if load_args is not None - else default_load_args - ) - - # load sql query from file - if sql: - self._load_args["sql"] = sql - self._filepath = None - else: - # filesystem for loading sql file - _fs_args = copy.deepcopy(fs_args) or {} - _fs_credentials = _fs_args.pop("credentials", {}) - protocol, path = get_protocol_and_path(str(filepath)) - - self._protocol = protocol - self._fs = fsspec.filesystem(self._protocol, **_fs_credentials, **_fs_args) - self._filepath = path - self._connection_str = credentials["con"] - self._execution_options = execution_options or {} - self.create_connection(self._connection_str) - - @classmethod - def create_connection(cls, connection_str: str) -> None: - """Given a connection string, create singleton connection - to be used across all instances of `SQLQueryDataSet` that - need to connect to the same source. - """ - if connection_str in cls.engines: - return - - try: - engine = create_engine(connection_str) - except ImportError as import_error: - raise _get_missing_module_error(import_error) from import_error - except NoSuchModuleError as exc: - raise _get_sql_alchemy_missing_error() from exc - - cls.engines[connection_str] = engine - - def _describe(self) -> Dict[str, Any]: - load_args = copy.deepcopy(self._load_args) - return { - "sql": str(load_args.pop("sql", None)), - "filepath": str(self._filepath), - "load_args": str(load_args), - "execution_options": str(self._execution_options), - } - - def _load(self) -> pd.DataFrame: - load_args = copy.deepcopy(self._load_args) - engine = self.engines[self._connection_str].execution_options( - **self._execution_options - ) # type: ignore - - if self._filepath: - load_path = get_filepath_str(PurePosixPath(self._filepath), self._protocol) - with self._fs.open(load_path, mode="r") as fs_file: - load_args["sql"] = fs_file.read() - - return pd.read_sql_query(con=engine, **load_args) - - def _save(self, data: None) -> NoReturn: - raise DatasetError("'save' is not supported on SQLQueryDataSet") diff --git a/kedro/extras/datasets/pandas/xml_dataset.py b/kedro/extras/datasets/pandas/xml_dataset.py deleted file mode 100644 index 30bd777252..0000000000 --- a/kedro/extras/datasets/pandas/xml_dataset.py +++ /dev/null @@ -1,171 +0,0 @@ -"""``XMLDataSet`` loads/saves data from/to a XML file using an underlying -filesystem (e.g.: local, S3, GCS). It uses pandas to handle the XML file. -""" -import logging -from copy import deepcopy -from io import BytesIO -from pathlib import PurePosixPath -from typing import Any, Dict - -import fsspec -import pandas as pd - -from kedro.io.core import ( - PROTOCOL_DELIMITER, - AbstractVersionedDataset, - DatasetError, - Version, - get_filepath_str, - get_protocol_and_path, -) - -logger = logging.getLogger(__name__) - -# NOTE: kedro.extras.datasets will be removed in Kedro 0.19.0. -# Any contribution to datasets should be made in kedro-datasets -# in kedro-plugins (https://github.com/kedro-org/kedro-plugins) - - -class XMLDataSet(AbstractVersionedDataset[pd.DataFrame, pd.DataFrame]): - """``XMLDataSet`` loads/saves data from/to a XML file using an underlying - filesystem (e.g.: local, S3, GCS). It uses pandas to handle the XML file. 
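A sketch of the ``execution_options`` handling in the removed ``SQLQueryDataSet._load`` above, pairing ``stream_results`` with a ``chunksize`` as in the advanced YAML example; the connection string is the placeholder one from the docstrings::

    import pandas as pd
    from sqlalchemy import create_engine

    engine = create_engine("postgresql://scott:tiger@localhost/test")
    streaming_engine = engine.execution_options(stream_results=True)

    sql = "select shuttle, shuttle_id from spaceflights.shuttles;"
    for chunk in pd.read_sql_query(sql, con=streaming_engine, chunksize=1000):
        ...  # each chunk is a DataFrame of up to 1000 rows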
- - Example usage for the - `Python API `_: - :: - - >>> from kedro.extras.datasets.pandas import XMLDataSet - >>> import pandas as pd - >>> - >>> data = pd.DataFrame({'col1': [1, 2], 'col2': [4, 5], - >>> 'col3': [5, 6]}) - >>> - >>> data_set = XMLDataSet(filepath="test.xml") - >>> data_set.save(data) - >>> reloaded = data_set.load() - >>> assert data.equals(reloaded) - - """ - - DEFAULT_LOAD_ARGS = {} # type: Dict[str, Any] - DEFAULT_SAVE_ARGS = {"index": False} # type: Dict[str, Any] - - def __init__( # noqa: too-many-arguments - self, - filepath: str, - load_args: Dict[str, Any] = None, - save_args: Dict[str, Any] = None, - version: Version = None, - credentials: Dict[str, Any] = None, - fs_args: Dict[str, Any] = None, - ) -> None: - """Creates a new instance of ``XMLDataSet`` pointing to a concrete XML file - on a specific filesystem. - - Args: - filepath: Filepath in POSIX format to a XML file prefixed with a protocol like `s3://`. - If prefix is not provided, `file` protocol (local filesystem) will be used. - The prefix should be any protocol supported by ``fsspec``. - Note: `http(s)` doesn't support versioning. - load_args: Pandas options for loading XML files. - Here you can find all available arguments: - https://pandas.pydata.org/docs/reference/api/pandas.read_xml.html - All defaults are preserved. - save_args: Pandas options for saving XML files. - Here you can find all available arguments: - https://pandas.pydata.org/docs/reference/api/pandas.DataFrame.to_xml.html - All defaults are preserved, but "index", which is set to False. - version: If specified, should be an instance of - ``kedro.io.core.Version``. If its ``load`` attribute is - None, the latest version will be loaded. If its ``save`` - attribute is None, save version will be autogenerated. - credentials: Credentials required to get access to the underlying filesystem. - E.g. for ``GCSFileSystem`` it should look like `{"token": None}`. - fs_args: Extra arguments to pass into underlying filesystem class constructor - (e.g. `{"project": "my-project"}` for ``GCSFileSystem``). 
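The ``version`` argument documented above (and on the other datasets in this diff) is a ``kedro.io.core.Version``: ``load=None`` means "latest available" and ``save=None`` means "autogenerate a save version". A minimal sketch of the removed class with a hypothetical filepath::

    from kedro.extras.datasets.pandas import XMLDataSet
    from kedro.io.core import Version

    versioned_xml = XMLDataSet(
        filepath="data/08_reporting/report.xml",  # hypothetical path
        version=Version(load=None, save=None),
    )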
- """ - _fs_args = deepcopy(fs_args) or {} - _credentials = deepcopy(credentials) or {} - - protocol, path = get_protocol_and_path(filepath, version) - if protocol == "file": - _fs_args.setdefault("auto_mkdir", True) - - self._protocol = protocol - self._storage_options = {**_credentials, **_fs_args} - self._fs = fsspec.filesystem(self._protocol, **self._storage_options) - - super().__init__( - filepath=PurePosixPath(path), - version=version, - exists_function=self._fs.exists, - glob_function=self._fs.glob, - ) - - # Handle default load and save arguments - self._load_args = deepcopy(self.DEFAULT_LOAD_ARGS) - if load_args is not None: - self._load_args.update(load_args) - self._save_args = deepcopy(self.DEFAULT_SAVE_ARGS) - if save_args is not None: - self._save_args.update(save_args) - - if "storage_options" in self._save_args or "storage_options" in self._load_args: - logger.warning( - "Dropping 'storage_options' for %s, " - "please specify them under 'fs_args' or 'credentials'.", - self._filepath, - ) - self._save_args.pop("storage_options", None) - self._load_args.pop("storage_options", None) - - def _describe(self) -> Dict[str, Any]: - return { - "filepath": self._filepath, - "protocol": self._load_args, - "save_args": self._save_args, - "version": self._version, - } - - def _load(self) -> pd.DataFrame: - load_path = str(self._get_load_path()) - if self._protocol == "file": - # file:// protocol seems to misbehave on Windows - # (), - # so we don't join that back to the filepath; - # storage_options also don't work with local paths - return pd.read_xml(load_path, **self._load_args) - - load_path = f"{self._protocol}{PROTOCOL_DELIMITER}{load_path}" - return pd.read_xml( - load_path, storage_options=self._storage_options, **self._load_args - ) - - def _save(self, data: pd.DataFrame) -> None: - save_path = get_filepath_str(self._get_save_path(), self._protocol) - - buf = BytesIO() - data.to_xml(path_or_buffer=buf, **self._save_args) - - with self._fs.open(save_path, mode="wb") as fs_file: - fs_file.write(buf.getvalue()) - - self._invalidate_cache() - - def _exists(self) -> bool: - try: - load_path = get_filepath_str(self._get_load_path(), self._protocol) - except DatasetError: - return False - - return self._fs.exists(load_path) - - def _release(self) -> None: - super()._release() - self._invalidate_cache() - - def _invalidate_cache(self) -> None: - """Invalidate underlying filesystem caches.""" - filepath = get_filepath_str(self._filepath, self._protocol) - self._fs.invalidate_cache(filepath) diff --git a/kedro/extras/datasets/pickle/__init__.py b/kedro/extras/datasets/pickle/__init__.py deleted file mode 100644 index 40b898eb07..0000000000 --- a/kedro/extras/datasets/pickle/__init__.py +++ /dev/null @@ -1,8 +0,0 @@ -"""``AbstractDataset`` implementation to load/save data from/to a Pickle file.""" - -__all__ = ["PickleDataSet"] - -from contextlib import suppress - -with suppress(ImportError): - from .pickle_dataset import PickleDataSet diff --git a/kedro/extras/datasets/pickle/pickle_dataset.py b/kedro/extras/datasets/pickle/pickle_dataset.py deleted file mode 100644 index 93bbbc2dbc..0000000000 --- a/kedro/extras/datasets/pickle/pickle_dataset.py +++ /dev/null @@ -1,245 +0,0 @@ -"""``PickleDataSet`` loads/saves data from/to a Pickle file using an underlying -filesystem (e.g.: local, S3, GCS). The underlying functionality is supported by -the specified backend library passed in (defaults to the ``pickle`` library), so it -supports all allowed options for loading and saving pickle files. 
-""" -import importlib -from copy import deepcopy -from pathlib import PurePosixPath -from typing import Any, Dict - -import fsspec - -from kedro.io.core import ( - AbstractVersionedDataset, - DatasetError, - Version, - get_filepath_str, - get_protocol_and_path, -) - -# NOTE: kedro.extras.datasets will be removed in Kedro 0.19.0. -# Any contribution to datasets should be made in kedro-datasets -# in kedro-plugins (https://github.com/kedro-org/kedro-plugins) - - -class PickleDataSet(AbstractVersionedDataset[Any, Any]): - """``PickleDataSet`` loads/saves data from/to a Pickle file using an underlying - filesystem (e.g.: local, S3, GCS). The underlying functionality is supported by - the specified backend library passed in (defaults to the ``pickle`` library), so it - supports all allowed options for loading and saving pickle files. - - Example usage for the - `YAML API `_: - - - .. code-block:: yaml - - test_model: # simple example without compression - type: pickle.PickleDataSet - filepath: data/07_model_output/test_model.pkl - backend: pickle - - final_model: # example with load and save args - type: pickle.PickleDataSet - filepath: s3://your_bucket/final_model.pkl.lz4 - backend: joblib - credentials: s3_credentials - save_args: - compress: lz4 - - Example usage for the - `Python API `_: - :: - - >>> from kedro.extras.datasets.pickle import PickleDataSet - >>> import pandas as pd - >>> - >>> data = pd.DataFrame({'col1': [1, 2], 'col2': [4, 5], - >>> 'col3': [5, 6]}) - >>> - >>> data_set = PickleDataSet(filepath="test.pkl", backend="pickle") - >>> data_set.save(data) - >>> reloaded = data_set.load() - >>> assert data.equals(reloaded) - >>> - >>> data_set = PickleDataSet(filepath="test.pickle.lz4", - >>> backend="compress_pickle", - >>> load_args={"compression":"lz4"}, - >>> save_args={"compression":"lz4"}) - >>> data_set.save(data) - >>> reloaded = data_set.load() - >>> assert data.equals(reloaded) - """ - - DEFAULT_LOAD_ARGS = {} # type: Dict[str, Any] - DEFAULT_SAVE_ARGS = {} # type: Dict[str, Any] - - def __init__( # noqa: too-many-arguments,too-many-locals - self, - filepath: str, - backend: str = "pickle", - load_args: Dict[str, Any] = None, - save_args: Dict[str, Any] = None, - version: Version = None, - credentials: Dict[str, Any] = None, - fs_args: Dict[str, Any] = None, - ) -> None: - """Creates a new instance of ``PickleDataSet`` pointing to a concrete Pickle - file on a specific filesystem. ``PickleDataSet`` supports custom backends to - serialise/deserialise objects. - - Example backends that are compatible (non-exhaustive): - * `pickle` - * `joblib` - * `dill` - * `compress_pickle` - - Example backends that are incompatible: - * `torch` - - Args: - filepath: Filepath in POSIX format to a Pickle file prefixed with a protocol like - `s3://`. If prefix is not provided, `file` protocol (local filesystem) will be used. - The prefix should be any protocol supported by ``fsspec``. - Note: `http(s)` doesn't support versioning. - backend: Backend to use, must be an import path to a module which satisfies the - ``pickle`` interface. That is, contains a `load` and `dump` function. - Defaults to 'pickle'. - load_args: Pickle options for loading pickle files. 
- You can pass in arguments that the backend load function specified accepts, e.g: - pickle.load: https://docs.python.org/3/library/pickle.html#pickle.load - joblib.load: https://joblib.readthedocs.io/en/latest/generated/joblib.load.html - dill.load: https://dill.readthedocs.io/en/latest/index.html#dill.load - compress_pickle.load: - https://lucianopaz.github.io/compress_pickle/html/api/compress_pickle.html#compress_pickle.compress_pickle.load - All defaults are preserved. - save_args: Pickle options for saving pickle files. - You can pass in arguments that the backend dump function specified accepts, e.g: - pickle.dump: https://docs.python.org/3/library/pickle.html#pickle.dump - joblib.dump: https://joblib.readthedocs.io/en/latest/generated/joblib.dump.html - dill.dump: https://dill.readthedocs.io/en/latest/index.html#dill.dump - compress_pickle.dump: - https://lucianopaz.github.io/compress_pickle/html/api/compress_pickle.html#compress_pickle.compress_pickle.dump - All defaults are preserved. - version: If specified, should be an instance of - ``kedro.io.core.Version``. If its ``load`` attribute is - None, the latest version will be loaded. If its ``save`` - attribute is None, save version will be autogenerated. - credentials: Credentials required to get access to the underlying filesystem. - E.g. for ``GCSFileSystem`` it should look like `{"token": None}`. - fs_args: Extra arguments to pass into underlying filesystem class constructor - (e.g. `{"project": "my-project"}` for ``GCSFileSystem``), as well as - to pass to the filesystem's `open` method through nested keys - `open_args_load` and `open_args_save`. - Here you can find all available arguments for `open`: - https://filesystem-spec.readthedocs.io/en/latest/api.html#fsspec.spec.AbstractFileSystem.open - All defaults are preserved, except `mode`, which is set to `wb` when saving. - - Raises: - ValueError: If ``backend`` does not satisfy the `pickle` interface. - ImportError: If the ``backend`` module could not be imported. - """ - # We do not store `imported_backend` as an attribute to be used in `load`/`save` - # as this would mean the dataset cannot be deepcopied (module objects cannot be - # pickled). The import here is purely to raise any errors as early as possible. - # Repeated imports in the `load` and `save` methods should not be a significant - # performance hit as Python caches imports. - try: - imported_backend = importlib.import_module(backend) - except ImportError as exc: - raise ImportError( - f"Selected backend '{backend}' could not be imported. " - "Make sure it is installed and importable." - ) from exc - - if not ( - hasattr(imported_backend, "load") and hasattr(imported_backend, "dump") - ): - raise ValueError( - f"Selected backend '{backend}' should satisfy the pickle interface. " - "Missing one of 'load' and 'dump' on the backend." 
- ) - - _fs_args = deepcopy(fs_args) or {} - _fs_open_args_load = _fs_args.pop("open_args_load", {}) - _fs_open_args_save = _fs_args.pop("open_args_save", {}) - _credentials = deepcopy(credentials) or {} - - protocol, path = get_protocol_and_path(filepath, version) - if protocol == "file": - _fs_args.setdefault("auto_mkdir", True) - - self._protocol = protocol - self._fs = fsspec.filesystem(self._protocol, **_credentials, **_fs_args) - - super().__init__( - filepath=PurePosixPath(path), - version=version, - exists_function=self._fs.exists, - glob_function=self._fs.glob, - ) - - self._backend = backend - - # Handle default load and save arguments - self._load_args = deepcopy(self.DEFAULT_LOAD_ARGS) - if load_args is not None: - self._load_args.update(load_args) - self._save_args = deepcopy(self.DEFAULT_SAVE_ARGS) - if save_args is not None: - self._save_args.update(save_args) - - _fs_open_args_save.setdefault("mode", "wb") - self._fs_open_args_load = _fs_open_args_load - self._fs_open_args_save = _fs_open_args_save - - def _describe(self) -> Dict[str, Any]: - return { - "filepath": self._filepath, - "backend": self._backend, - "protocol": self._protocol, - "load_args": self._load_args, - "save_args": self._save_args, - "version": self._version, - } - - def _load(self) -> Any: - load_path = get_filepath_str(self._get_load_path(), self._protocol) - - with self._fs.open(load_path, **self._fs_open_args_load) as fs_file: - imported_backend = importlib.import_module(self._backend) - return imported_backend.load(fs_file, **self._load_args) # type: ignore - - def _save(self, data: Any) -> None: - save_path = get_filepath_str(self._get_save_path(), self._protocol) - - with self._fs.open(save_path, **self._fs_open_args_save) as fs_file: - try: - imported_backend = importlib.import_module(self._backend) - imported_backend.dump(data, fs_file, **self._save_args) # type: ignore - except Exception as exc: - raise DatasetError( - f"{data.__class__} was not serialised due to: {exc}" - ) from exc - - self._invalidate_cache() - - def _exists(self) -> bool: - try: - load_path = get_filepath_str(self._get_load_path(), self._protocol) - except DatasetError: - return False - - return self._fs.exists(load_path) - - def _release(self) -> None: - super()._release() - self._invalidate_cache() - - def _invalidate_cache(self) -> None: - """Invalidate underlying filesystem caches.""" - filepath = get_filepath_str(self._filepath, self._protocol) - self._fs.invalidate_cache(filepath) diff --git a/kedro/extras/datasets/pillow/__init__.py b/kedro/extras/datasets/pillow/__init__.py deleted file mode 100644 index 03df85f3ee..0000000000 --- a/kedro/extras/datasets/pillow/__init__.py +++ /dev/null @@ -1,8 +0,0 @@ -"""``AbstractDataset`` implementation to load/save image data.""" - -__all__ = ["ImageDataSet"] - -from contextlib import suppress - -with suppress(ImportError): - from .image_dataset import ImageDataSet diff --git a/kedro/extras/datasets/pillow/image_dataset.py b/kedro/extras/datasets/pillow/image_dataset.py deleted file mode 100644 index a403b74b27..0000000000 --- a/kedro/extras/datasets/pillow/image_dataset.py +++ /dev/null @@ -1,143 +0,0 @@ -"""``ImageDataSet`` loads/saves image data as `numpy` from an underlying -filesystem (e.g.: local, S3, GCS). It uses Pillow to handle image file. 
-""" -from copy import deepcopy -from pathlib import PurePosixPath -from typing import Any, Dict - -import fsspec -from PIL import Image - -from kedro.io.core import ( - AbstractVersionedDataset, - DatasetError, - Version, - get_filepath_str, - get_protocol_and_path, -) - -# NOTE: kedro.extras.datasets will be removed in Kedro 0.19.0. -# Any contribution to datasets should be made in kedro-datasets -# in kedro-plugins (https://github.com/kedro-org/kedro-plugins) - - -class ImageDataSet(AbstractVersionedDataset[Image.Image, Image.Image]): - """``ImageDataSet`` loads/saves image data as `numpy` from an underlying - filesystem (e.g.: local, S3, GCS). It uses Pillow to handle image file. - - Example usage for the - `Python API `_: - :: - - >>> from kedro.extras.datasets.pillow import ImageDataSet - >>> - >>> data_set = ImageDataSet(filepath="test.png") - >>> image = data_set.load() - >>> image.show() - - """ - - DEFAULT_SAVE_ARGS = {} # type: Dict[str, Any] - - def __init__( # noqa: too-many-arguments - self, - filepath: str, - save_args: Dict[str, Any] = None, - version: Version = None, - credentials: Dict[str, Any] = None, - fs_args: Dict[str, Any] = None, - ) -> None: - """Creates a new instance of ``ImageDataSet`` pointing to a concrete image file - on a specific filesystem. - - Args: - filepath: Filepath in POSIX format to an image file prefixed with a protocol like - `s3://`. If prefix is not provided, `file` protocol (local filesystem) will be used. - The prefix should be any protocol supported by ``fsspec``. - Note: `http(s)` doesn't support versioning. - save_args: Pillow options for saving image files. - Here you can find all available arguments: - https://pillow.readthedocs.io/en/stable/reference/Image.html#PIL.Image.Image.save - All defaults are preserved. - version: If specified, should be an instance of - ``kedro.io.core.Version``. If its ``load`` attribute is - None, the latest version will be loaded. If its ``save`` - attribute is None, save version will be autogenerated. - credentials: Credentials required to get access to the underlying filesystem. - E.g. for ``GCSFileSystem`` it should look like `{"token": None}`. - fs_args: Extra arguments to pass into underlying filesystem class constructor - (e.g. `{"project": "my-project"}` for ``GCSFileSystem``), as well as - to pass to the filesystem's `open` method through nested keys - `open_args_load` and `open_args_save`. - Here you can find all available arguments for `open`: - https://filesystem-spec.readthedocs.io/en/latest/api.html#fsspec.spec.AbstractFileSystem.open - All defaults are preserved, except `mode`, which is set to `r` when loading - and to `w` when saving. 
- """ - _fs_args = deepcopy(fs_args) or {} - _fs_open_args_load = _fs_args.pop("open_args_load", {}) - _fs_open_args_save = _fs_args.pop("open_args_save", {}) - _credentials = deepcopy(credentials) or {} - - protocol, path = get_protocol_and_path(filepath, version) - if protocol == "file": - _fs_args.setdefault("auto_mkdir", True) - - self._protocol = protocol - self._fs = fsspec.filesystem(self._protocol, **_credentials, **_fs_args) - - super().__init__( - filepath=PurePosixPath(path), - version=version, - exists_function=self._fs.exists, - glob_function=self._fs.glob, - ) - - # Handle default save argument - self._save_args = deepcopy(self.DEFAULT_SAVE_ARGS) - if save_args is not None: - self._save_args.update(save_args) - - _fs_open_args_save.setdefault("mode", "wb") - self._fs_open_args_load = _fs_open_args_load - self._fs_open_args_save = _fs_open_args_save - - def _describe(self) -> Dict[str, Any]: - return { - "filepath": self._filepath, - "protocol": self._protocol, - "save_args": self._save_args, - "version": self._version, - } - - def _load(self) -> Image.Image: - load_path = get_filepath_str(self._get_load_path(), self._protocol) - - with self._fs.open(load_path, **self._fs_open_args_load) as fs_file: - return Image.open(fs_file).copy() - - def _save(self, data: Image.Image) -> None: - save_path = get_filepath_str(self._get_save_path(), self._protocol) - - with self._fs.open(save_path, **self._fs_open_args_save) as fs_file: - data.save(fs_file, **self._save_args) - - self._invalidate_cache() - - def _exists(self) -> bool: - try: - load_path = get_filepath_str(self._get_load_path(), self._protocol) - except DatasetError: - return False - - return self._fs.exists(load_path) - - def _release(self) -> None: - super()._release() - self._invalidate_cache() - - def _invalidate_cache(self) -> None: - """Invalidate underlying filesystem caches.""" - filepath = get_filepath_str(self._filepath, self._protocol) - self._fs.invalidate_cache(filepath) diff --git a/kedro/extras/datasets/plotly/__init__.py b/kedro/extras/datasets/plotly/__init__.py deleted file mode 100644 index c2851bb000..0000000000 --- a/kedro/extras/datasets/plotly/__init__.py +++ /dev/null @@ -1,11 +0,0 @@ -"""``AbstractDataset`` implementations to load/save a plotly figure from/to a JSON -file.""" - -__all__ = ["PlotlyDataSet", "JSONDataSet"] - -from contextlib import suppress - -with suppress(ImportError): - from .plotly_dataset import PlotlyDataSet -with suppress(ImportError): - from .json_dataset import JSONDataSet diff --git a/kedro/extras/datasets/plotly/json_dataset.py b/kedro/extras/datasets/plotly/json_dataset.py deleted file mode 100644 index 3c686ab896..0000000000 --- a/kedro/extras/datasets/plotly/json_dataset.py +++ /dev/null @@ -1,167 +0,0 @@ -"""``JSONDataSet`` loads/saves a plotly figure from/to a JSON file using an underlying -filesystem (e.g.: local, S3, GCS). -""" -from copy import deepcopy -from pathlib import PurePosixPath -from typing import Any, Dict, Union - -import fsspec -import plotly.io as pio -from plotly import graph_objects as go - -from kedro.io.core import ( - AbstractVersionedDataset, - Version, - get_filepath_str, - get_protocol_and_path, -) - -# NOTE: kedro.extras.datasets will be removed in Kedro 0.19.0. 
-# Any contribution to datasets should be made in kedro-datasets -# in kedro-plugins (https://github.com/kedro-org/kedro-plugins) - - -class JSONDataSet( - AbstractVersionedDataset[go.Figure, Union[go.Figure, go.FigureWidget]] -): - """``JSONDataSet`` loads/saves a plotly figure from/to a JSON file using an - underlying filesystem (e.g.: local, S3, GCS). - - Example usage for the - `YAML API `_: - - - .. code-block:: yaml - - scatter_plot: - type: plotly.JSONDataSet - filepath: data/08_reporting/scatter_plot.json - save_args: - engine: auto - - Example usage for the - `Python API `_: - :: - - >>> from kedro.extras.datasets.plotly import JSONDataSet - >>> import plotly.express as px - >>> - >>> fig = px.bar(x=["a", "b", "c"], y=[1, 3, 2]) - >>> data_set = JSONDataSet(filepath="test.json") - >>> data_set.save(fig) - >>> reloaded = data_set.load() - >>> assert fig == reloaded - """ - - DEFAULT_LOAD_ARGS = {} # type: Dict[str, Any] - DEFAULT_SAVE_ARGS = {} # type: Dict[str, Any] - - def __init__( # noqa: too-many-arguments - self, - filepath: str, - load_args: Dict[str, Any] = None, - save_args: Dict[str, Any] = None, - version: Version = None, - credentials: Dict[str, Any] = None, - fs_args: Dict[str, Any] = None, - ) -> None: - """Creates a new instance of ``JSONDataSet`` pointing to a concrete JSON file - on a specific filesystem. - - Args: - filepath: Filepath in POSIX format to a JSON file prefixed with a protocol like `s3://`. - If prefix is not provided `file` protocol (local filesystem) will be used. - The prefix should be any protocol supported by ``fsspec``. - Note: `http(s)` doesn't support versioning. - load_args: Plotly options for loading JSON files. - Here you can find all available arguments: - https://plotly.com/python-api-reference/generated/plotly.io.from_json.html#plotly.io.from_json - All defaults are preserved. - save_args: Plotly options for saving JSON files. - Here you can find all available arguments: - https://plotly.com/python-api-reference/generated/plotly.io.write_json.html - All defaults are preserved. - version: If specified, should be an instance of - ``kedro.io.core.Version``. If its ``load`` attribute is - None, the latest version will be loaded. If its ``save`` - attribute is None, save version will be autogenerated. - credentials: Credentials required to get access to the underlying filesystem. - E.g. for ``GCSFileSystem`` it should look like `{'token': None}`. - fs_args: Extra arguments to pass into underlying filesystem class constructor - (e.g. `{"project": "my-project"}` for ``GCSFileSystem``), as well as - to pass to the filesystem's `open` method through nested keys - `open_args_load` and `open_args_save`. - Here you can find all available arguments for `open`: - https://filesystem-spec.readthedocs.io/en/latest/api.html#fsspec.spec.AbstractFileSystem.open - All defaults are preserved, except `mode`, which is set to `w` when - saving. 
- """ - _fs_args = deepcopy(fs_args) or {} - _fs_open_args_load = _fs_args.pop("open_args_load", {}) - _fs_open_args_save = _fs_args.pop("open_args_save", {}) - _credentials = deepcopy(credentials) or {} - - protocol, path = get_protocol_and_path(filepath, version) - if protocol == "file": - _fs_args.setdefault("auto_mkdir", True) - - self._protocol = protocol - self._fs = fsspec.filesystem(self._protocol, **_credentials, **_fs_args) - - super().__init__( - filepath=PurePosixPath(path), - version=version, - exists_function=self._fs.exists, - glob_function=self._fs.glob, - ) - - # Handle default load and save arguments - self._load_args = deepcopy(self.DEFAULT_LOAD_ARGS) - if load_args is not None: - self._load_args.update(load_args) - self._save_args = deepcopy(self.DEFAULT_SAVE_ARGS) - if save_args is not None: - self._save_args.update(save_args) - - _fs_open_args_save.setdefault("mode", "w") - self._fs_open_args_load = _fs_open_args_load - self._fs_open_args_save = _fs_open_args_save - - def _describe(self) -> Dict[str, Any]: - return { - "filepath": self._filepath, - "protocol": self._load_args, - "save_args": self._save_args, - "version": self._version, - } - - def _load(self) -> Union[go.Figure, go.FigureWidget]: - load_path = get_filepath_str(self._get_load_path(), self._protocol) - - with self._fs.open(load_path, **self._fs_open_args_load) as fs_file: - # read_json doesn't work correctly with file handler, so we have to read - # the file, decode it manually and pass to the low-level from_json instead. - return pio.from_json(str(fs_file.read(), "utf-8"), **self._load_args) - - def _save(self, data: go.Figure) -> None: - save_path = get_filepath_str(self._get_save_path(), self._protocol) - - with self._fs.open(save_path, **self._fs_open_args_save) as fs_file: - data.write_json(fs_file, **self._save_args) - - self._invalidate_cache() - - def _exists(self) -> bool: - load_path = get_filepath_str(self._get_load_path(), self._protocol) - - return self._fs.exists(load_path) - - def _release(self) -> None: - super()._release() - self._invalidate_cache() - - def _invalidate_cache(self) -> None: - filepath = get_filepath_str(self._filepath, self._protocol) - self._fs.invalidate_cache(filepath) diff --git a/kedro/extras/datasets/plotly/plotly_dataset.py b/kedro/extras/datasets/plotly/plotly_dataset.py deleted file mode 100644 index 7cb6477b25..0000000000 --- a/kedro/extras/datasets/plotly/plotly_dataset.py +++ /dev/null @@ -1,142 +0,0 @@ -"""``PlotlyDataSet`` generates a plot from a pandas DataFrame and saves it to a JSON -file using an underlying filesystem (e.g.: local, S3, GCS). It loads the JSON into a -plotly figure. -""" -from copy import deepcopy -from typing import Any, Dict - -import pandas as pd -import plotly.express as px -from plotly import graph_objects as go - -from kedro.io.core import Version - -from .json_dataset import JSONDataSet - -# NOTE: kedro.extras.datasets will be removed in Kedro 0.19.0. -# Any contribution to datasets should be made in kedro-datasets -# in kedro-plugins (https://github.com/kedro-org/kedro-plugins) - - -class PlotlyDataSet(JSONDataSet): - """``PlotlyDataSet`` generates a plot from a pandas DataFrame and saves it to a JSON - file using an underlying filesystem (e.g.: local, S3, GCS). It loads the JSON into a - plotly figure. - - ``PlotlyDataSet`` is a convenience wrapper for ``plotly.JSONDataSet``. It generates - the JSON file directly from a pandas DataFrame through ``plotly_args``. - - Example usage for the - `YAML API `_: - - - .. 
code-block:: yaml - - bar_plot: - type: plotly.PlotlyDataSet - filepath: data/08_reporting/bar_plot.json - plotly_args: - type: bar - fig: - x: features - y: importance - orientation: h - layout: - xaxis_title: x - yaxis_title: y - title: Title - - Example usage for the - `Python API `_: - :: - - >>> from kedro.extras.datasets.plotly import PlotlyDataSet - >>> import plotly.express as px - >>> import pandas as pd - >>> - >>> df_data = pd.DataFrame([[0, 1], [1, 0]], columns=('x1', 'x2')) - >>> - >>> data_set = PlotlyDataSet( - >>> filepath='scatter_plot.json', - >>> plotly_args={ - >>> 'type': 'scatter', - >>> 'fig': {'x': 'x1', 'y': 'x2'}, - >>> } - >>> ) - >>> data_set.save(df_data) - >>> reloaded = data_set.load() - >>> assert px.scatter(df_data, x='x1', y='x2') == reloaded - - """ - - def __init__( # noqa: too-many-arguments - self, - filepath: str, - plotly_args: Dict[str, Any], - load_args: Dict[str, Any] = None, - save_args: Dict[str, Any] = None, - version: Version = None, - credentials: Dict[str, Any] = None, - fs_args: Dict[str, Any] = None, - ) -> None: - """Creates a new instance of ``PlotlyDataSet`` pointing to a concrete JSON file - on a specific filesystem. - - Args: - filepath: Filepath in POSIX format to a JSON file prefixed with a protocol like `s3://`. - If prefix is not provided `file` protocol (local filesystem) will be used. - The prefix should be any protocol supported by ``fsspec``. - Note: `http(s)` doesn't support versioning. - plotly_args: Plotly configuration for generating a plotly figure from the - dataframe. Keys are `type` (plotly express function, e.g. bar, - line, scatter), `fig` (kwargs passed to the plotting function), theme - (defaults to `plotly`), `layout`. - load_args: Plotly options for loading JSON files. - Here you can find all available arguments: - https://plotly.com/python-api-reference/generated/plotly.io.from_json.html#plotly.io.from_json - All defaults are preserved. - save_args: Plotly options for saving JSON files. - Here you can find all available arguments: - https://plotly.com/python-api-reference/generated/plotly.io.write_json.html - All defaults are preserved. - version: If specified, should be an instance of - ``kedro.io.core.Version``. If its ``load`` attribute is - None, the latest version will be loaded. If its ``save`` - attribute is None, save version will be autogenerated. - credentials: Credentials required to get access to the underlying filesystem. - E.g. for ``GCSFileSystem`` it should look like `{'token': None}`. - fs_args: Extra arguments to pass into underlying filesystem class constructor - (e.g. `{"project": "my-project"}` for ``GCSFileSystem``), as well as - to pass to the filesystem's `open` method through nested keys - `open_args_load` and `open_args_save`. - Here you can find all available arguments for `open`: - https://filesystem-spec.readthedocs.io/en/latest/api.html#fsspec.spec.AbstractFileSystem.open - All defaults are preserved, except `mode`, which is set to `w` when saving. 
- """ - super().__init__(filepath, load_args, save_args, version, credentials, fs_args) - self._plotly_args = plotly_args - - _fs_args = deepcopy(fs_args) or {} - _fs_open_args_load = _fs_args.pop("open_args_load", {}) - _fs_open_args_save = _fs_args.pop("open_args_save", {}) - _fs_open_args_save.setdefault("mode", "w") - - self._fs_open_args_load = _fs_open_args_load - self._fs_open_args_save = _fs_open_args_save - - def _describe(self) -> Dict[str, Any]: - return {**super()._describe(), "plotly_args": self._plotly_args} - - def _save(self, data: pd.DataFrame) -> None: - fig = self._plot_dataframe(data) - super()._save(fig) - - def _plot_dataframe(self, data: pd.DataFrame) -> go.Figure: - plot_type = self._plotly_args.get("type") - fig_params = self._plotly_args.get("fig", {}) - fig = getattr(px, plot_type)(data, **fig_params) # type: ignore - fig.update_layout(template=self._plotly_args.get("theme", "plotly")) - fig.update_layout(self._plotly_args.get("layout", {})) - return fig diff --git a/kedro/extras/datasets/redis/__init__.py b/kedro/extras/datasets/redis/__init__.py deleted file mode 100644 index f3c553ec3b..0000000000 --- a/kedro/extras/datasets/redis/__init__.py +++ /dev/null @@ -1,8 +0,0 @@ -"""``AbstractDataset`` implementation to load/save data from/to a redis db.""" - -__all__ = ["PickleDataSet"] - -from contextlib import suppress - -with suppress(ImportError): - from .redis_dataset import PickleDataSet diff --git a/kedro/extras/datasets/redis/redis_dataset.py b/kedro/extras/datasets/redis/redis_dataset.py deleted file mode 100644 index d4d7b11f74..0000000000 --- a/kedro/extras/datasets/redis/redis_dataset.py +++ /dev/null @@ -1,191 +0,0 @@ -"""``PickleDataSet`` loads/saves data from/to a Redis database. The underlying -functionality is supported by the redis library, so it supports all allowed -options for instantiating the redis app ``from_url`` and setting a value.""" - -import importlib -import os -from copy import deepcopy -from typing import Any, Dict - -import redis - -from kedro.io.core import AbstractDataset, DatasetError - -# NOTE: kedro.extras.datasets will be removed in Kedro 0.19.0. -# Any contribution to datasets should be made in kedro-datasets -# in kedro-plugins (https://github.com/kedro-org/kedro-plugins) - - -class PickleDataSet(AbstractDataset[Any, Any]): - """``PickleDataSet`` loads/saves data from/to a Redis database. The - underlying functionality is supported by the redis library, so it supports - all allowed options for instantiating the redis app ``from_url`` and setting - a value. - - Example usage for the - `YAML API `_: - - - .. 
code-block:: yaml - - my_python_object: # simple example - type: redis.PickleDataSet - key: my_object - from_url_args: - url: redis://127.0.0.1:6379 - - final_python_object: # example with save args - type: redis.PickleDataSet - key: my_final_object - from_url_args: - url: redis://127.0.0.1:6379 - db: 1 - save_args: - ex: 10 - - Example usage for the - `Python API `_: - :: - - >>> from kedro.extras.datasets.redis import PickleDataSet - >>> import pandas as pd - >>> - >>> data = pd.DataFrame({'col1': [1, 2], 'col2': [4, 5], - >>> 'col3': [5, 6]}) - >>> - >>> my_data = PickleDataSet(key="my_data") - >>> my_data.save(data) - >>> reloaded = my_data.load() - >>> assert data.equals(reloaded) - """ - - DEFAULT_REDIS_URL = os.getenv("REDIS_URL", "redis://127.0.0.1:6379") - DEFAULT_LOAD_ARGS = {} # type: Dict[str, Any] - DEFAULT_SAVE_ARGS = {} # type: Dict[str, Any] - - def __init__( # noqa: too-many-arguments - self, - key: str, - backend: str = "pickle", - load_args: Dict[str, Any] = None, - save_args: Dict[str, Any] = None, - credentials: Dict[str, Any] = None, - redis_args: Dict[str, Any] = None, - ) -> None: - """Creates a new instance of ``PickleDataSet``. This loads/saves data from/to - a Redis database while deserialising/serialising. Supports custom backends to - serialise/deserialise objects. - - Example backends that are compatible (non-exhaustive): - * `pickle` - * `dill` - * `compress_pickle` - - Example backends that are incompatible: - * `torch` - - Args: - key: The key to use for saving/loading object to Redis. - backend: Backend to use, must be an import path to a module which satisfies the - ``pickle`` interface. That is, contains a `loads` and `dumps` function. - Defaults to 'pickle'. - load_args: Pickle options for loading pickle files. - You can pass in arguments that the backend load function specified accepts, e.g: - pickle.loads: https://docs.python.org/3/library/pickle.html#pickle.loads - dill.loads: https://dill.readthedocs.io/en/latest/index.html#dill.loads - compress_pickle.loads: - https://lucianopaz.github.io/compress_pickle/html/api/compress_pickle.html#compress_pickle.compress_pickle.loads - All defaults are preserved. - save_args: Pickle options for saving pickle files. - You can pass in arguments that the backend dump function specified accepts, e.g: - pickle.dumps: https://docs.python.org/3/library/pickle.html#pickle.dump - dill.dumps: https://dill.readthedocs.io/en/latest/index.html#dill.dumps - compress_pickle.dumps: - https://lucianopaz.github.io/compress_pickle/html/api/compress_pickle.html#compress_pickle.compress_pickle.dumps - All defaults are preserved. - credentials: Credentials required to get access to the redis server. - E.g. `{"password": None}`. - redis_args: Extra arguments to pass into the redis client constructor - ``redis.StrictRedis.from_url``. (e.g. `{"socket_timeout": 10}`), as well as to pass - to the ``redis.StrictRedis.set`` through nested keys `from_url_args` and `set_args`. - Here you can find all available arguments for `from_url`: - https://redis-py.readthedocs.io/en/stable/connections.html?highlight=from_url#redis.Redis.from_url - All defaults are preserved, except `url`, which is set to `redis://127.0.0.1:6379`. - You could also specify the url through the env variable ``REDIS_URL``. - - Raises: - ValueError: If ``backend`` does not satisfy the `pickle` interface. - ImportError: If the ``backend`` module could not be imported. 
- """ - try: - imported_backend = importlib.import_module(backend) - except ImportError as exc: - raise ImportError( - f"Selected backend '{backend}' could not be imported. " - "Make sure it is installed and importable." - ) from exc - - if not ( - hasattr(imported_backend, "loads") and hasattr(imported_backend, "dumps") - ): - raise ValueError( - f"Selected backend '{backend}' should satisfy the pickle interface. " - "Missing one of 'loads' and 'dumps' on the backend." - ) - - self._backend = backend - - self._key = key - - _redis_args = deepcopy(redis_args) or {} - self._redis_from_url_args = _redis_args.pop("from_url_args", {}) - self._redis_from_url_args.setdefault("url", self.DEFAULT_REDIS_URL) - self._redis_set_args = _redis_args.pop("set_args", {}) - _credentials = deepcopy(credentials) or {} - - self._load_args = deepcopy(self.DEFAULT_LOAD_ARGS) - if load_args is not None: - self._load_args.update(load_args) - self._save_args = deepcopy(self.DEFAULT_SAVE_ARGS) - if save_args is not None: - self._save_args.update(save_args) - - self._redis_db = redis.Redis.from_url( - **self._redis_from_url_args, **_credentials - ) - - def _describe(self) -> Dict[str, Any]: - return {"key": self._key, **self._redis_from_url_args} - - # `redis_db` mypy does not work since it is optional and optional is not - # accepted by pickle.loads. - def _load(self) -> Any: - if not self.exists(): - raise DatasetError(f"The provided key {self._key} does not exists.") - imported_backend = importlib.import_module(self._backend) - return imported_backend.loads( # type: ignore - self._redis_db.get(self._key), **self._load_args - ) # type: ignore - - def _save(self, data: Any) -> None: - try: - imported_backend = importlib.import_module(self._backend) - self._redis_db.set( - self._key, - imported_backend.dumps(data, **self._save_args), # type: ignore - **self._redis_set_args, - ) - except Exception as exc: - raise DatasetError( - f"{data.__class__} was not serialised due to: {exc}" - ) from exc - - def _exists(self) -> bool: - try: - return bool(self._redis_db.exists(self._key)) - except Exception as exc: - raise DatasetError( - f"The existence of key {self._key} could not be established due to: {exc}" - ) from exc diff --git a/kedro/extras/datasets/spark/__init__.py b/kedro/extras/datasets/spark/__init__.py deleted file mode 100644 index 3dede09aa8..0000000000 --- a/kedro/extras/datasets/spark/__init__.py +++ /dev/null @@ -1,14 +0,0 @@ -"""Provides I/O modules for Apache Spark.""" - -__all__ = ["SparkDataSet", "SparkHiveDataSet", "SparkJDBCDataSet", "DeltaTableDataSet"] - -from contextlib import suppress - -with suppress(ImportError): - from .spark_dataset import SparkDataSet -with suppress(ImportError): - from .spark_hive_dataset import SparkHiveDataSet -with suppress(ImportError): - from .spark_jdbc_dataset import SparkJDBCDataSet -with suppress(ImportError): - from .deltatable_dataset import DeltaTableDataSet diff --git a/kedro/extras/datasets/spark/deltatable_dataset.py b/kedro/extras/datasets/spark/deltatable_dataset.py deleted file mode 100644 index 6df51fcdd7..0000000000 --- a/kedro/extras/datasets/spark/deltatable_dataset.py +++ /dev/null @@ -1,116 +0,0 @@ -"""``AbstractDataset`` implementation to access DeltaTables using -``delta-spark`` -""" -from pathlib import PurePosixPath -from typing import NoReturn - -from delta.tables import DeltaTable -from pyspark.sql import SparkSession -from pyspark.sql.utils import AnalysisException - -from kedro.extras.datasets.spark.spark_dataset import ( - _split_filepath, - 
_strip_dbfs_prefix, -) -from kedro.io.core import AbstractDataset, DatasetError - -# NOTE: kedro.extras.datasets will be removed in Kedro 0.19.0. -# Any contribution to datasets should be made in kedro-datasets -# in kedro-plugins (https://github.com/kedro-org/kedro-plugins) - - -class DeltaTableDataSet(AbstractDataset[None, DeltaTable]): - """``DeltaTableDataSet`` loads data into DeltaTable objects. - - Example usage for the - `YAML API `_: - - - .. code-block:: yaml - - weather@spark: - type: spark.SparkDataSet - filepath: data/02_intermediate/data.parquet - file_format: "delta" - - weather@delta: - type: spark.DeltaTableDataSet - filepath: data/02_intermediate/data.parquet - - Example usage for the - `Python API `_: - :: - - >>> from pyspark.sql import SparkSession - >>> from pyspark.sql.types import (StructField, StringType, - >>> IntegerType, StructType) - >>> - >>> from kedro.extras.datasets.spark import DeltaTableDataSet, SparkDataSet - >>> - >>> schema = StructType([StructField("name", StringType(), True), - >>> StructField("age", IntegerType(), True)]) - >>> - >>> data = [('Alex', 31), ('Bob', 12), ('Clarke', 65), ('Dave', 29)] - >>> - >>> spark_df = SparkSession.builder.getOrCreate().createDataFrame(data, schema) - >>> - >>> data_set = SparkDataSet(filepath="test_data", file_format="delta") - >>> data_set.save(spark_df) - >>> deltatable_dataset = DeltaTableDataSet(filepath="test_data") - >>> delta_table = deltatable_dataset.load() - >>> - >>> delta_table.update() - """ - - # this dataset cannot be used with ``ParallelRunner``, - # therefore it has the attribute ``_SINGLE_PROCESS = True`` - # for parallelism within a Spark pipeline please consider - # using ``ThreadRunner`` instead - _SINGLE_PROCESS = True - - def __init__(self, filepath: str) -> None: - """Creates a new instance of ``DeltaTableDataSet``. - - Args: - filepath: Filepath in POSIX format to a Spark dataframe. When using Databricks - and working with data written to mount path points, - specify ``filepath``s for (versioned) ``SparkDataSet``s - starting with ``/dbfs/mnt``. 
- """ - fs_prefix, filepath = _split_filepath(filepath) - - self._fs_prefix = fs_prefix - self._filepath = PurePosixPath(filepath) - - @staticmethod - def _get_spark(): - return SparkSession.builder.getOrCreate() - - def _load(self) -> DeltaTable: - load_path = self._fs_prefix + str(self._filepath) - return DeltaTable.forPath(self._get_spark(), load_path) - - def _save(self, data: None) -> NoReturn: - raise DatasetError(f"{self.__class__.__name__} is a read only dataset type") - - def _exists(self) -> bool: - load_path = _strip_dbfs_prefix(self._fs_prefix + str(self._filepath)) - - try: - self._get_spark().read.load(path=load_path, format="delta") - except AnalysisException as exception: - # `AnalysisException.desc` is deprecated with pyspark >= 3.4 - message = ( - exception.desc if hasattr(exception, "desc") else exception.message - ) - - if "Path does not exist:" in message or "is not a Delta table" in message: - return False - raise - - return True - - def _describe(self): - return {"filepath": str(self._filepath), "fs_prefix": self._fs_prefix} diff --git a/kedro/extras/datasets/spark/spark_dataset.py b/kedro/extras/datasets/spark/spark_dataset.py deleted file mode 100644 index 0547b3e804..0000000000 --- a/kedro/extras/datasets/spark/spark_dataset.py +++ /dev/null @@ -1,427 +0,0 @@ -"""``AbstractVersionedDataset`` implementation to access Spark dataframes using -``pyspark`` -""" -import json -from copy import deepcopy -from fnmatch import fnmatch -from functools import partial -from pathlib import PurePosixPath -from typing import Any, Dict, List, Optional, Tuple -from warnings import warn - -import fsspec -from hdfs import HdfsError, InsecureClient -from pyspark.sql import DataFrame, SparkSession -from pyspark.sql.types import StructType -from pyspark.sql.utils import AnalysisException -from s3fs import S3FileSystem - -from kedro.io.core import ( - AbstractVersionedDataset, - DatasetError, - Version, - get_filepath_str, - get_protocol_and_path, -) - -# NOTE: kedro.extras.datasets will be removed in Kedro 0.19.0. -# Any contribution to datasets should be made in kedro-datasets -# in kedro-plugins (https://github.com/kedro-org/kedro-plugins) - - -def _parse_glob_pattern(pattern: str) -> str: - special = ("*", "?", "[") - clean = [] - for part in pattern.split("/"): - if any(char in part for char in special): - break - clean.append(part) - return "/".join(clean) - - -def _split_filepath(filepath: str) -> Tuple[str, str]: - split_ = filepath.split("://", 1) - MIN_SPLIT_SIZE = 2 - if len(split_) == MIN_SPLIT_SIZE: - return split_[0] + "://", split_[1] - return "", split_[0] - - -def _strip_dbfs_prefix(path: str, prefix: str = "/dbfs") -> str: - return path[len(prefix) :] if path.startswith(prefix) else path - - -def _dbfs_glob(pattern: str, dbutils: Any) -> List[str]: - """Perform a custom glob search in DBFS using the provided pattern. - It is assumed that version paths are managed by Kedro only. - - Args: - pattern: Glob pattern to search for. - dbutils: dbutils instance to operate with DBFS. - - Returns: - List of DBFS paths prefixed with '/dbfs' that satisfy the glob pattern. 
- """ - pattern = _strip_dbfs_prefix(pattern) - prefix = _parse_glob_pattern(pattern) - matched = set() - filename = pattern.split("/")[-1] - - for file_info in dbutils.fs.ls(prefix): - if file_info.isDir(): - path = str( - PurePosixPath(_strip_dbfs_prefix(file_info.path, "dbfs:")) / filename - ) - if fnmatch(path, pattern): - path = "/dbfs" + path - matched.add(path) - return sorted(matched) - - -def _get_dbutils(spark: SparkSession) -> Optional[Any]: - """Get the instance of 'dbutils' or None if the one could not be found.""" - dbutils = globals().get("dbutils") - if dbutils: - return dbutils - - try: - from pyspark.dbutils import DBUtils # pylint: disable=import-outside-toplevel - - dbutils = DBUtils(spark) - except ImportError: - try: - import IPython # pylint: disable=import-outside-toplevel - except ImportError: - pass - else: - ipython = IPython.get_ipython() - dbutils = ipython.user_ns.get("dbutils") if ipython else None - - return dbutils - - -def _dbfs_exists(pattern: str, dbutils: Any) -> bool: - """Perform an `ls` list operation in DBFS using the provided pattern. - It is assumed that version paths are managed by Kedro. - Broad `Exception` is present due to `dbutils.fs.ExecutionError` that - cannot be imported directly. - Args: - pattern: Filepath to search for. - dbutils: dbutils instance to operate with DBFS. - Returns: - Boolean value if filepath exists. - """ - pattern = _strip_dbfs_prefix(pattern) - file = _parse_glob_pattern(pattern) - try: - dbutils.fs.ls(file) - return True - except Exception: # pylint: disable=broad-except - return False - - -class KedroHdfsInsecureClient(InsecureClient): - """Subclasses ``hdfs.InsecureClient`` and implements ``hdfs_exists`` - and ``hdfs_glob`` methods required by ``SparkDataSet``""" - - def hdfs_exists(self, hdfs_path: str) -> bool: - """Determines whether given ``hdfs_path`` exists in HDFS. - - Args: - hdfs_path: Path to check. - - Returns: - True if ``hdfs_path`` exists in HDFS, False otherwise. - """ - return bool(self.status(hdfs_path, strict=False)) - - def hdfs_glob(self, pattern: str) -> List[str]: - """Perform a glob search in HDFS using the provided pattern. - - Args: - pattern: Glob pattern to search for. - - Returns: - List of HDFS paths that satisfy the glob pattern. - """ - prefix = _parse_glob_pattern(pattern) or "/" - matched = set() - try: - for dpath, _, fnames in self.walk(prefix): - if fnmatch(dpath, pattern): - matched.add(dpath) - matched |= { - f"{dpath}/{fname}" - for fname in fnames - if fnmatch(f"{dpath}/{fname}", pattern) - } - except HdfsError: # pragma: no cover - # HdfsError is raised by `self.walk()` if prefix does not exist in HDFS. - # Ignore and return an empty list. - pass - return sorted(matched) - - -class SparkDataSet(AbstractVersionedDataset[DataFrame, DataFrame]): - """``SparkDataSet`` loads and saves Spark dataframes. - - Example usage for the - `YAML API `_: - - - .. 
code-block:: yaml - - weather: - type: spark.SparkDataSet - filepath: s3a://your_bucket/data/01_raw/weather/* - file_format: csv - load_args: - header: True - inferSchema: True - save_args: - sep: '|' - header: True - - weather_with_schema: - type: spark.SparkDataSet - filepath: s3a://your_bucket/data/01_raw/weather/* - file_format: csv - load_args: - header: True - schema: - filepath: path/to/schema.json - save_args: - sep: '|' - header: True - - weather_cleaned: - type: spark.SparkDataSet - filepath: data/02_intermediate/data.parquet - file_format: parquet - - Example usage for the - `Python API `_: - :: - - >>> from pyspark.sql import SparkSession - >>> from pyspark.sql.types import (StructField, StringType, - >>> IntegerType, StructType) - >>> - >>> from kedro.extras.datasets.spark import SparkDataSet - >>> - >>> schema = StructType([StructField("name", StringType(), True), - >>> StructField("age", IntegerType(), True)]) - >>> - >>> data = [('Alex', 31), ('Bob', 12), ('Clarke', 65), ('Dave', 29)] - >>> - >>> spark_df = SparkSession.builder.getOrCreate()\ - >>> .createDataFrame(data, schema) - >>> - >>> data_set = SparkDataSet(filepath="test_data") - >>> data_set.save(spark_df) - >>> reloaded = data_set.load() - >>> - >>> reloaded.take(4) - """ - - # this dataset cannot be used with ``ParallelRunner``, - # therefore it has the attribute ``_SINGLE_PROCESS = True`` - # for parallelism within a Spark pipeline please consider - # ``ThreadRunner`` instead - _SINGLE_PROCESS = True - DEFAULT_LOAD_ARGS = {} # type: Dict[str, Any] - DEFAULT_SAVE_ARGS = {} # type: Dict[str, Any] - - def __init__( # ruff: noqa: PLR0913 - self, - filepath: str, - file_format: str = "parquet", - load_args: Dict[str, Any] = None, - save_args: Dict[str, Any] = None, - version: Version = None, - credentials: Dict[str, Any] = None, - ) -> None: - """Creates a new instance of ``SparkDataSet``. - - Args: - filepath: Filepath in POSIX format to a Spark dataframe. When using Databricks - and working with data written to mount path points, - specify ``filepath``s for (versioned) ``SparkDataSet``s - starting with ``/dbfs/mnt``. - file_format: File format used during load and save - operations. These are formats supported by the running - SparkContext include parquet, csv, delta. For a list of supported - formats please refer to Apache Spark documentation at - https://spark.apache.org/docs/latest/sql-programming-guide.html - load_args: Load args passed to Spark DataFrameReader load method. - It is dependent on the selected file format. You can find - a list of read options for each supported format - in Spark DataFrame read documentation: - https://spark.apache.org/docs/latest/api/python/getting_started/quickstart_df.html - save_args: Save args passed to Spark DataFrame write options. - Similar to load_args this is dependent on the selected file - format. You can pass ``mode`` and ``partitionBy`` to specify - your overwrite mode and partitioning respectively. You can find - a list of options for each format in Spark DataFrame - write documentation: - https://spark.apache.org/docs/latest/api/python/getting_started/quickstart_df.html - version: If specified, should be an instance of - ``kedro.io.core.Version``. If its ``load`` attribute is - None, the latest version will be loaded. If its ``save`` - attribute is None, save version will be autogenerated. - credentials: Credentials to access the S3 bucket, such as - ``key``, ``secret``, if ``filepath`` prefix is ``s3a://`` or ``s3n://``. 
- Optional keyword arguments passed to ``hdfs.client.InsecureClient`` - if ``filepath`` prefix is ``hdfs://``. Ignored otherwise. - """ - credentials = deepcopy(credentials) or {} - fs_prefix, filepath = _split_filepath(filepath) - exists_function = None - glob_function = None - - if fs_prefix in ("s3a://", "s3n://"): - if fs_prefix == "s3n://": - warn( - "'s3n' filesystem has now been deprecated by Spark, " - "please consider switching to 's3a'", - DeprecationWarning, - ) - _s3 = S3FileSystem(**credentials) - exists_function = _s3.exists - glob_function = partial(_s3.glob, refresh=True) - path = PurePosixPath(filepath) - - elif fs_prefix == "hdfs://" and version: - warn( - f"HDFS filesystem support for versioned {self.__class__.__name__} is " - f"in beta and uses 'hdfs.client.InsecureClient', please use with " - f"caution" - ) - - # default namenode address - credentials.setdefault("url", "http://localhost:9870") - credentials.setdefault("user", "hadoop") - - _hdfs_client = KedroHdfsInsecureClient(**credentials) - exists_function = _hdfs_client.hdfs_exists - glob_function = _hdfs_client.hdfs_glob # type: ignore - path = PurePosixPath(filepath) - - else: - path = PurePosixPath(filepath) - - if filepath.startswith("/dbfs"): - dbutils = _get_dbutils(self._get_spark()) - if dbutils: - glob_function = partial(_dbfs_glob, dbutils=dbutils) - exists_function = partial(_dbfs_exists, dbutils=dbutils) - - super().__init__( - filepath=path, - version=version, - exists_function=exists_function, - glob_function=glob_function, - ) - - # Handle default load and save arguments - self._load_args = deepcopy(self.DEFAULT_LOAD_ARGS) - if load_args is not None: - self._load_args.update(load_args) - self._save_args = deepcopy(self.DEFAULT_SAVE_ARGS) - if save_args is not None: - self._save_args.update(save_args) - - # Handle schema load argument - self._schema = self._load_args.pop("schema", None) - if self._schema is not None: - if isinstance(self._schema, dict): - self._schema = self._load_schema_from_file(self._schema) - - self._file_format = file_format - self._fs_prefix = fs_prefix - self._handle_delta_format() - - @staticmethod - def _load_schema_from_file(schema: Dict[str, Any]) -> StructType: - - filepath = schema.get("filepath") - if not filepath: - raise DatasetError( - "Schema load argument does not specify a 'filepath' attribute. Please" - "include a path to a JSON-serialised 'pyspark.sql.types.StructType'." - ) - - credentials = deepcopy(schema.get("credentials")) or {} - protocol, schema_path = get_protocol_and_path(filepath) - file_system = fsspec.filesystem(protocol, **credentials) - pure_posix_path = PurePosixPath(schema_path) - load_path = get_filepath_str(pure_posix_path, protocol) - - # Open schema file - with file_system.open(load_path) as fs_file: - - try: - return StructType.fromJson(json.loads(fs_file.read())) - except Exception as exc: - raise DatasetError( - f"Contents of 'schema.filepath' ({schema_path}) are invalid. Please" - f"provide a valid JSON-serialised 'pyspark.sql.types.StructType'." 
- ) from exc - - def _describe(self) -> Dict[str, Any]: - return { - "filepath": self._fs_prefix + str(self._filepath), - "file_format": self._file_format, - "load_args": self._load_args, - "save_args": self._save_args, - "version": self._version, - } - - @staticmethod - def _get_spark(): - return SparkSession.builder.getOrCreate() - - def _load(self) -> DataFrame: - load_path = _strip_dbfs_prefix(self._fs_prefix + str(self._get_load_path())) - read_obj = self._get_spark().read - - # Pass schema if defined - if self._schema: - read_obj = read_obj.schema(self._schema) - - return read_obj.load(load_path, self._file_format, **self._load_args) - - def _save(self, data: DataFrame) -> None: - save_path = _strip_dbfs_prefix(self._fs_prefix + str(self._get_save_path())) - data.write.save(save_path, self._file_format, **self._save_args) - - def _exists(self) -> bool: - load_path = _strip_dbfs_prefix(self._fs_prefix + str(self._get_load_path())) - - try: - self._get_spark().read.load(load_path, self._file_format) - except AnalysisException as exception: - # `AnalysisException.desc` is deprecated with pyspark >= 3.4 - message = ( - exception.desc if hasattr(exception, "desc") else exception.message - ) - if "Path does not exist:" in message or "is not a Delta table" in message: - return False - raise - return True - - def _handle_delta_format(self) -> None: - supported_modes = {"append", "overwrite", "error", "errorifexists", "ignore"} - write_mode = self._save_args.get("mode") - if ( - write_mode - and self._file_format == "delta" - and write_mode not in supported_modes - ): - raise DatasetError( - f"It is not possible to perform 'save()' for file format 'delta' " - f"with mode '{write_mode}' on 'SparkDataSet'. " - f"Please use 'spark.DeltaTableDataSet' instead." - ) diff --git a/kedro/extras/datasets/spark/spark_hive_dataset.py b/kedro/extras/datasets/spark/spark_hive_dataset.py deleted file mode 100644 index 746f7ae6df..0000000000 --- a/kedro/extras/datasets/spark/spark_hive_dataset.py +++ /dev/null @@ -1,224 +0,0 @@ -"""``AbstractDataset`` implementation to access Spark dataframes using -``pyspark`` on Apache Hive. -""" -import pickle -from copy import deepcopy -from typing import Any, Dict, List - -from pyspark.sql import DataFrame, SparkSession, Window -from pyspark.sql.functions import col, lit, row_number - -from kedro.io.core import AbstractDataset, DatasetError - -# NOTE: kedro.extras.datasets will be removed in Kedro 0.19.0. -# Any contribution to datasets should be made in kedro-datasets -# in kedro-plugins (https://github.com/kedro-org/kedro-plugins) - - -# pylint:disable=too-many-instance-attributes -class SparkHiveDataSet(AbstractDataset[DataFrame, DataFrame]): - """``SparkHiveDataSet`` loads and saves Spark dataframes stored on Hive. - This data set also handles some incompatible file types such as using partitioned parquet on - hive which will not normally allow upserts to existing data without a complete replacement - of the existing file/partition. - - This DataSet has some key assumptions: - - - Schemas do not change during the pipeline run (defined PKs must be present for the - duration of the pipeline) - - Tables are not being externally modified during upserts. The upsert method is NOT ATOMIC - - to external changes to the target table while executing. - Upsert methodology works by leveraging Spark DataFrame execution plan checkpointing. - - Example usage for the - `YAML API `_: - - - .. 
code-block:: yaml - - hive_dataset: - type: spark.SparkHiveDataSet - database: hive_database - table: table_name - write_mode: overwrite - - Example usage for the - `Python API `_: - :: - - >>> from pyspark.sql import SparkSession - >>> from pyspark.sql.types import (StructField, StringType, - >>> IntegerType, StructType) - >>> - >>> from kedro.extras.datasets.spark import SparkHiveDataSet - >>> - >>> schema = StructType([StructField("name", StringType(), True), - >>> StructField("age", IntegerType(), True)]) - >>> - >>> data = [('Alex', 31), ('Bob', 12), ('Clarke', 65), ('Dave', 29)] - >>> - >>> spark_df = SparkSession.builder.getOrCreate().createDataFrame(data, schema) - >>> - >>> data_set = SparkHiveDataSet(database="test_database", table="test_table", - >>> write_mode="overwrite") - >>> data_set.save(spark_df) - >>> reloaded = data_set.load() - >>> - >>> reloaded.take(4) - """ - - DEFAULT_SAVE_ARGS = {} # type: Dict[str, Any] - - def __init__( # noqa: too-many-arguments - self, - database: str, - table: str, - write_mode: str = "errorifexists", - table_pk: List[str] = None, - save_args: Dict[str, Any] = None, - ) -> None: - """Creates a new instance of ``SparkHiveDataSet``. - - Args: - database: The name of the hive database. - table: The name of the table within the database. - write_mode: ``insert``, ``upsert`` or ``overwrite`` are supported. - table_pk: If performing an upsert, this identifies the primary key columns used to - resolve preexisting data. Is required for ``write_mode="upsert"``. - save_args: Optional mapping of any options, - passed to the `DataFrameWriter.saveAsTable` as kwargs. - Key example of this is `partitionBy` which allows data partitioning - on a list of column names. - Other `HiveOptions` can be found here: - https://spark.apache.org/docs/latest/sql-data-sources-hive-tables.html#specifying-storage-format-for-hive-tables - - Note: - For users leveraging the `upsert` functionality, - a `checkpoint` directory must be set, e.g. using - `spark.sparkContext.setCheckpointDir("/path/to/dir")` - or directly in the Spark conf folder. - - Raises: - DatasetError: Invalid configuration supplied - """ - _write_modes = ["append", "error", "errorifexists", "upsert", "overwrite"] - if write_mode not in _write_modes: - valid_modes = ", ".join(_write_modes) - raise DatasetError( - f"Invalid 'write_mode' provided: {write_mode}. " - f"'write_mode' must be one of: {valid_modes}" - ) - if write_mode == "upsert" and not table_pk: - raise DatasetError("'table_pk' must be set to utilise 'upsert' read mode") - - self._write_mode = write_mode - self._table_pk = table_pk or [] - self._database = database - self._table = table - self._full_table_address = f"{database}.{table}" - self._save_args = deepcopy(self.DEFAULT_SAVE_ARGS) - if save_args is not None: - self._save_args.update(save_args) - self._format = self._save_args.pop("format", None) or "hive" - self._eager_checkpoint = self._save_args.pop("eager_checkpoint", None) or True - - def _describe(self) -> Dict[str, Any]: - return { - "database": self._database, - "table": self._table, - "write_mode": self._write_mode, - "table_pk": self._table_pk, - "partition_by": self._save_args.get("partitionBy"), - "format": self._format, - } - - @staticmethod - def _get_spark() -> SparkSession: - """ - This method should only be used to get an existing SparkSession - with valid Hive configuration. - Configuration for Hive is read from hive-site.xml on the classpath. - It supports running both SQL and HiveQL commands. 
- Additionally, if users are leveraging the `upsert` functionality, - then a `checkpoint` directory must be set, e.g. using - `spark.sparkContext.setCheckpointDir("/path/to/dir")` - """ - _spark = SparkSession.builder.getOrCreate() - return _spark - - def _create_hive_table(self, data: DataFrame, mode: str = None): - _mode: str = mode or self._write_mode - data.write.saveAsTable( - self._full_table_address, - mode=_mode, - format=self._format, - **self._save_args, - ) - - def _load(self) -> DataFrame: - return self._get_spark().read.table(self._full_table_address) - - def _save(self, data: DataFrame) -> None: - self._validate_save(data) - if self._write_mode == "upsert": - # check if _table_pk is a subset of df columns - if not set(self._table_pk) <= set(self._load().columns): - raise DatasetError( - f"Columns {str(self._table_pk)} selected as primary key(s) not found in " - f"table {self._full_table_address}" - ) - self._upsert_save(data=data) - else: - self._create_hive_table(data=data) - - def _upsert_save(self, data: DataFrame) -> None: - if not self._exists() or self._load().rdd.isEmpty(): - self._create_hive_table(data=data, mode="overwrite") - else: - _tmp_colname = "tmp_colname" - _tmp_row = "tmp_row" - _w = Window.partitionBy(*self._table_pk).orderBy(col(_tmp_colname).desc()) - df_old = self._load().select("*", lit(1).alias(_tmp_colname)) - df_new = data.select("*", lit(2).alias(_tmp_colname)) - df_stacked = df_new.unionByName(df_old).select( - "*", row_number().over(_w).alias(_tmp_row) - ) - df_filtered = ( - df_stacked.filter(col(_tmp_row) == 1) - .drop(_tmp_colname, _tmp_row) - .checkpoint(eager=self._eager_checkpoint) - ) - self._create_hive_table(data=df_filtered, mode="overwrite") - - def _validate_save(self, data: DataFrame): - # do not validate when the table doesn't exist - # or if the `write_mode` is set to overwrite - if (not self._exists()) or self._write_mode == "overwrite": - return - hive_dtypes = set(self._load().dtypes) - data_dtypes = set(data.dtypes) - if data_dtypes != hive_dtypes: - new_cols = data_dtypes - hive_dtypes - missing_cols = hive_dtypes - data_dtypes - raise DatasetError( - f"Dataset does not match hive table schema.\n" - f"Present on insert only: {sorted(new_cols)}\n" - f"Present on schema only: {sorted(missing_cols)}" - ) - - def _exists(self) -> bool: - # noqa # noqa: protected-access - return ( - self._get_spark() - ._jsparkSession.catalog() - .tableExists(self._database, self._table) - ) - - def __getstate__(self) -> None: - raise pickle.PicklingError( - "PySpark datasets objects cannot be pickled " - "or serialised as Python objects." - ) diff --git a/kedro/extras/datasets/spark/spark_jdbc_dataset.py b/kedro/extras/datasets/spark/spark_jdbc_dataset.py deleted file mode 100644 index bacb492cbd..0000000000 --- a/kedro/extras/datasets/spark/spark_jdbc_dataset.py +++ /dev/null @@ -1,179 +0,0 @@ -"""SparkJDBCDataSet to load and save a PySpark DataFrame via JDBC.""" - -from copy import deepcopy -from typing import Any, Dict - -from pyspark.sql import DataFrame, SparkSession - -from kedro.io.core import AbstractDataset, DatasetError - -__all__ = ["SparkJDBCDataSet"] - -# NOTE: kedro.extras.datasets will be removed in Kedro 0.19.0. 
-# Any contribution to datasets should be made in kedro-datasets -# in kedro-plugins (https://github.com/kedro-org/kedro-plugins) - - -class SparkJDBCDataSet(AbstractDataset[DataFrame, DataFrame]): - """``SparkJDBCDataSet`` loads data from a database table accessible - via JDBC URL url and connection properties and saves the content of - a PySpark DataFrame to an external database table via JDBC. It uses - ``pyspark.sql.DataFrameReader`` and ``pyspark.sql.DataFrameWriter`` - internally, so it supports all allowed PySpark options on ``jdbc``. - - Example usage for the - `YAML API `_: - - - .. code-block:: yaml - - weather: - type: spark.SparkJDBCDataSet - table: weather_table - url: jdbc:postgresql://localhost/test - credentials: db_credentials - load_args: - properties: - driver: org.postgresql.Driver - save_args: - properties: - driver: org.postgresql.Driver - - Example usage for the - `Python API `_: - :: - - >>> import pandas as pd - >>> - >>> from pyspark.sql import SparkSession - >>> - >>> spark = SparkSession.builder.getOrCreate() - >>> data = spark.createDataFrame(pd.DataFrame({'col1': [1, 2], - >>> 'col2': [4, 5], - >>> 'col3': [5, 6]})) - >>> url = 'jdbc:postgresql://localhost/test' - >>> table = 'table_a' - >>> connection_properties = {'driver': 'org.postgresql.Driver'} - >>> data_set = SparkJDBCDataSet( - >>> url=url, table=table, credentials={'user': 'scott', - >>> 'password': 'tiger'}, - >>> load_args={'properties': connection_properties}, - >>> save_args={'properties': connection_properties}) - >>> - >>> data_set.save(data) - >>> reloaded = data_set.load() - >>> - >>> assert data.toPandas().equals(reloaded.toPandas()) - - """ - - DEFAULT_LOAD_ARGS = {} # type: Dict[str, Any] - DEFAULT_SAVE_ARGS = {} # type: Dict[str, Any] - - def __init__( # noqa: too-many-arguments - self, - url: str, - table: str, - credentials: Dict[str, Any] = None, - load_args: Dict[str, Any] = None, - save_args: Dict[str, Any] = None, - ) -> None: - """Creates a new ``SparkJDBCDataSet``. - - Args: - url: A JDBC URL of the form ``jdbc:subprotocol:subname``. - table: The name of the table to load or save data to. - credentials: A dictionary of JDBC database connection arguments. - Normally at least properties ``user`` and ``password`` with - their corresponding values. It updates ``properties`` - parameter in ``load_args`` and ``save_args`` in case it is - provided. - load_args: Provided to underlying PySpark ``jdbc`` function along - with the JDBC URL and the name of the table. To find all - supported arguments, see here: - https://spark.apache.org/docs/latest/api/python/reference/pyspark.sql/api/pyspark.sql.DataFrameWriter.jdbc.html - save_args: Provided to underlying PySpark ``jdbc`` function along - with the JDBC URL and the name of the table. To find all - supported arguments, see here: - https://spark.apache.org/docs/latest/api/python/reference/pyspark.sql/api/pyspark.sql.DataFrameWriter.jdbc.html - - Raises: - DatasetError: When either ``url`` or ``table`` is empty or - when a property is provided with a None value. - """ - - if not url: - raise DatasetError( - "'url' argument cannot be empty. Please " - "provide a JDBC URL of the form " - "'jdbc:subprotocol:subname'." - ) - - if not table: - raise DatasetError( - "'table' argument cannot be empty. Please " - "provide the name of the table to load or save " - "data to." 
- ) - - self._url = url - self._table = table - - # Handle default load and save arguments - self._load_args = deepcopy(self.DEFAULT_LOAD_ARGS) - if load_args is not None: - self._load_args.update(load_args) - self._save_args = deepcopy(self.DEFAULT_SAVE_ARGS) - if save_args is not None: - self._save_args.update(save_args) - - # Update properties in load_args and save_args with credentials. - if credentials is not None: - - # Check credentials for bad inputs. - for cred_key, cred_value in credentials.items(): - if cred_value is None: - raise DatasetError( - f"Credential property '{cred_key}' cannot be None. " - f"Please provide a value." - ) - - load_properties = self._load_args.get("properties", {}) - save_properties = self._save_args.get("properties", {}) - self._load_args["properties"] = {**load_properties, **credentials} - self._save_args["properties"] = {**save_properties, **credentials} - - def _describe(self) -> Dict[str, Any]: - load_args = self._load_args - save_args = self._save_args - - # Remove user and password values from load and save properties. - if "properties" in load_args: - load_properties = load_args["properties"].copy() - load_properties.pop("user", None) - load_properties.pop("password", None) - load_args = {**load_args, "properties": load_properties} - if "properties" in save_args: - save_properties = save_args["properties"].copy() - save_properties.pop("user", None) - save_properties.pop("password", None) - save_args = {**save_args, "properties": save_properties} - - return { - "url": self._url, - "table": self._table, - "load_args": load_args, - "save_args": save_args, - } - - @staticmethod - def _get_spark(): # pragma: no cover - return SparkSession.builder.getOrCreate() - - def _load(self) -> DataFrame: - return self._get_spark().read.jdbc(self._url, self._table, **self._load_args) - - def _save(self, data: DataFrame) -> None: - return data.write.jdbc(self._url, self._table, **self._save_args) diff --git a/kedro/extras/datasets/svmlight/__init__.py b/kedro/extras/datasets/svmlight/__init__.py deleted file mode 100644 index 4b77f3dfde..0000000000 --- a/kedro/extras/datasets/svmlight/__init__.py +++ /dev/null @@ -1,8 +0,0 @@ -"""``AbstractDataset`` implementation to load/save data from/to a svmlight/ -libsvm sparse data file.""" -__all__ = ["SVMLightDataSet"] - -from contextlib import suppress - -with suppress(ImportError): - from .svmlight_dataset import SVMLightDataSet diff --git a/kedro/extras/datasets/svmlight/svmlight_dataset.py b/kedro/extras/datasets/svmlight/svmlight_dataset.py deleted file mode 100644 index af4a1323ad..0000000000 --- a/kedro/extras/datasets/svmlight/svmlight_dataset.py +++ /dev/null @@ -1,169 +0,0 @@ -"""``SVMLightDataSet`` loads/saves data from/to a svmlight/libsvm file using an -underlying filesystem (e.g.: local, S3, GCS). It uses sklearn functions -``dump_svmlight_file`` to save and ``load_svmlight_file`` to load a file. -""" -from copy import deepcopy -from pathlib import PurePosixPath -from typing import Any, Dict, Optional, Tuple, Union - -import fsspec -from numpy import ndarray -from scipy.sparse.csr import csr_matrix -from sklearn.datasets import dump_svmlight_file, load_svmlight_file - -from kedro.io.core import ( - AbstractVersionedDataset, - DatasetError, - Version, - get_filepath_str, - get_protocol_and_path, -) - -# NOTE: kedro.extras.datasets will be removed in Kedro 0.19.0. 
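`SparkJDBCDataSet._load` and `_save` above are thin wrappers around `DataFrameReader.jdbc` and `DataFrameWriter.jdbc`, with `credentials` merged into the `properties` dict of both `load_args` and `save_args`. A hedged sketch of the equivalent direct calls, reusing the placeholder connection details from the class docstring:

```python
# Direct PySpark equivalent of what SparkJDBCDataSet delegates to.
# URL, table and credentials are the placeholder values from the docstring example.
from pyspark.sql import SparkSession

spark = SparkSession.builder.getOrCreate()
url = "jdbc:postgresql://localhost/test"
table = "table_a"
properties = {
    "driver": "org.postgresql.Driver",
    "user": "scott",       # merged in from `credentials` by the dataset
    "password": "tiger",
}

df = spark.read.jdbc(url, table, properties=properties)          # _load
df.write.jdbc(url, table, mode="append", properties=properties)  # _save (mode chosen here for illustration)
```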
-# Any contribution to datasets should be made in kedro-datasets -# in kedro-plugins (https://github.com/kedro-org/kedro-plugins) - -# Type of data input -_DI = Tuple[Union[ndarray, csr_matrix], ndarray] -# Type of data output -_DO = Tuple[csr_matrix, ndarray] - - -class SVMLightDataSet(AbstractVersionedDataset[_DI, _DO]): - """``SVMLightDataSet`` loads/saves data from/to a svmlight/libsvm file using an - underlying filesystem (e.g.: local, S3, GCS). It uses sklearn functions - ``dump_svmlight_file`` to save and ``load_svmlight_file`` to load a file. - - Data is loaded as a tuple of features and labels. Labels is NumPy array, - and features is Compressed Sparse Row matrix. - - This format is a text-based format, with one sample per line. It does - not store zero valued features hence it is suitable for sparse datasets. - - This format is used as the default format for both svmlight and the - libsvm command line programs. - - Example usage for the - `YAML API `_: - - - .. code-block:: yaml - - svm_dataset: - type: svmlight.SVMLightDataSet - filepath: data/01_raw/location.svm - load_args: - zero_based: False - save_args: - zero_based: False - - cars: - type: svmlight.SVMLightDataSet - filepath: gcs://your_bucket/cars.svm - fs_args: - project: my-project - credentials: my_gcp_credentials - load_args: - zero_based: False - save_args: - zero_based: False - - Example usage for the - `Python API `_: - :: - - >>> from kedro.extras.datasets.svmlight import SVMLightDataSet - >>> import numpy as np - >>> - >>> # Features and labels. - >>> data = (np.array([[0, 1], [2, 3.14159]]), np.array([7, 3])) - >>> - >>> data_set = SVMLightDataSet(filepath="test.svm") - >>> data_set.save(data) - >>> reloaded_features, reloaded_labels = data_set.load() - >>> assert (data[0] == reloaded_features).all() - >>> assert (data[1] == reloaded_labels).all() - - """ - - DEFAULT_LOAD_ARGS = {} # type: Dict[str, Any] - DEFAULT_SAVE_ARGS = {} # type: Dict[str, Any] - - def __init__( # noqa: too-many-arguments - self, - filepath: str, - load_args: Dict[str, Any] = None, - save_args: Dict[str, Any] = None, - version: Optional[Version] = None, - credentials: Dict[str, Any] = None, - fs_args: Dict[str, Any] = None, - ) -> None: - _fs_args = deepcopy(fs_args) or {} - _fs_open_args_load = _fs_args.pop("open_args_load", {}) - _fs_open_args_save = _fs_args.pop("open_args_save", {}) - _credentials = deepcopy(credentials) or {} - - protocol, path = get_protocol_and_path(filepath, version) - - self._protocol = protocol - if protocol == "file": - _fs_args.setdefault("auto_mkdir", True) - self._fs = fsspec.filesystem(self._protocol, **_credentials, **_fs_args) - - super().__init__( - filepath=PurePosixPath(path), - version=version, - exists_function=self._fs.exists, - glob_function=self._fs.glob, - ) - - self._load_args = deepcopy(self.DEFAULT_LOAD_ARGS) - if load_args is not None: - self._load_args.update(load_args) - self._save_args = deepcopy(self.DEFAULT_SAVE_ARGS) - if save_args is not None: - self._save_args.update(save_args) - - _fs_open_args_load.setdefault("mode", "rb") - _fs_open_args_save.setdefault("mode", "wb") - self._fs_open_args_load = _fs_open_args_load - self._fs_open_args_save = _fs_open_args_save - - def _describe(self): - return { - "filepath": self._filepath, - "protocol": self._load_args, - "save_args": self._save_args, - "version": self._version, - } - - def _load(self) -> _DO: - load_path = get_filepath_str(self._get_load_path(), self._protocol) - with self._fs.open(load_path, **self._fs_open_args_load) as 
fs_file: - return load_svmlight_file(fs_file, **self._load_args) - - def _save(self, data: _DI) -> None: - save_path = get_filepath_str(self._get_save_path(), self._protocol) - with self._fs.open(save_path, **self._fs_open_args_save) as fs_file: - dump_svmlight_file(data[0], data[1], fs_file, **self._save_args) - - self._invalidate_cache() - - def _exists(self) -> bool: - try: - load_path = get_filepath_str(self._get_load_path(), self._protocol) - except DatasetError: - return False - - return self._fs.exists(load_path) - - def _release(self) -> None: - super()._release() - self._invalidate_cache() - - def _invalidate_cache(self) -> None: - """Invalidate underlying filesystem caches.""" - filepath = get_filepath_str(self._filepath, self._protocol) - self._fs.invalidate_cache(filepath) diff --git a/kedro/extras/datasets/tensorflow/README.md b/kedro/extras/datasets/tensorflow/README.md deleted file mode 100644 index 704d164977..0000000000 --- a/kedro/extras/datasets/tensorflow/README.md +++ /dev/null @@ -1,34 +0,0 @@ -# TensorFlowModelDataset - -``TensorflowModelDataset`` loads and saves TensorFlow models. -The underlying functionality is supported by, and passes input arguments to TensorFlow 2.X load_model and save_model methods. Only TF2 is currently supported for saving and loading, V1 requires HDF5 and serialises differently. - -#### Example use: -```python -import numpy as np -import tensorflow as tf - -from kedro.extras.datasets.tensorflow import TensorFlowModelDataset - -data_set = TensorFlowModelDataset("tf_model_dirname") - -model = tf.keras.Model() -predictions = model.predict([...]) - -data_set.save(model) -loaded_model = data_set.load() - -new_predictions = loaded_model.predict([...]) -np.testing.assert_allclose(predictions, new_predictions, rtol=1e-6, atol=1e-6) -``` - -#### Example catalog.yml: -```yaml -example_tensorflow_data: - type: tensorflow.TensorFlowModelDataset - filepath: data/08_reporting/tf_model_dirname - load_args: - tf_device: "/CPU:0" # optional -``` - -Contributed by (Aleks Hughes)[https://github.com/w0rdsm1th]. diff --git a/kedro/extras/datasets/tensorflow/__init__.py b/kedro/extras/datasets/tensorflow/__init__.py deleted file mode 100644 index 20e1311ded..0000000000 --- a/kedro/extras/datasets/tensorflow/__init__.py +++ /dev/null @@ -1,8 +0,0 @@ -"""Provides I/O for TensorFlow Models.""" - -__all__ = ["TensorFlowModelDataset"] - -from contextlib import suppress - -with suppress(ImportError): - from .tensorflow_model_dataset import TensorFlowModelDataset diff --git a/kedro/extras/datasets/tensorflow/tensorflow_model_dataset.py b/kedro/extras/datasets/tensorflow/tensorflow_model_dataset.py deleted file mode 100644 index ce6043b18d..0000000000 --- a/kedro/extras/datasets/tensorflow/tensorflow_model_dataset.py +++ /dev/null @@ -1,195 +0,0 @@ -"""``TensorflowModelDataset`` is a data set implementation which can save and load -TensorFlow models. -""" -import copy -import tempfile -from pathlib import PurePath, PurePosixPath -from typing import Any, Dict - -import fsspec -import tensorflow as tf - -from kedro.io.core import ( - AbstractVersionedDataset, - DatasetError, - Version, - get_filepath_str, - get_protocol_and_path, -) - -TEMPORARY_H5_FILE = "tmp_tensorflow_model.h5" - -# NOTE: kedro.extras.datasets will be removed in Kedro 0.19.0. 
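For the removed `SVMLightDataSet`: the class wraps scikit-learn's `dump_svmlight_file` and `load_svmlight_file`, and its docstring describes the one-sample-per-line sparse text format. A local round-trip sketch under those assumptions; the file name is a placeholder, and `zero_based=False` mirrors the YAML example above:

```python
# Round-trip through the svmlight/libsvm text format wrapped by SVMLightDataSet.
# "example.svm" is a placeholder path.
import numpy as np
from sklearn.datasets import dump_svmlight_file, load_svmlight_file

X = np.array([[0.0, 1.0], [2.0, 3.14159]])  # zero-valued features are not written
y = np.array([7, 3])

with open("example.svm", "wb") as f:
    dump_svmlight_file(X, y, f, zero_based=False)

with open("example.svm", "rb") as f:
    features, labels = load_svmlight_file(f, zero_based=False)  # features is a CSR matrix

assert (features.toarray() == X).all()
assert (labels == y).all()
```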
-# Any contribution to datasets should be made in kedro-datasets -# in kedro-plugins (https://github.com/kedro-org/kedro-plugins) - - -class TensorFlowModelDataset(AbstractVersionedDataset[tf.keras.Model, tf.keras.Model]): - """``TensorflowModelDataset`` loads and saves TensorFlow models. - The underlying functionality is supported by, and passes input arguments through to, - TensorFlow 2.X load_model and save_model methods. - - Example usage for the - `YAML API `_: - - - .. code-block:: yaml - - tensorflow_model: - type: tensorflow.TensorFlowModelDataset - filepath: data/06_models/tensorflow_model.h5 - load_args: - compile: False - save_args: - overwrite: True - include_optimizer: False - credentials: tf_creds - - Example usage for the - `Python API `_: - :: - - >>> from kedro.extras.datasets.tensorflow import TensorFlowModelDataset - >>> import tensorflow as tf - >>> import numpy as np - >>> - >>> data_set = TensorFlowModelDataset("data/06_models/tensorflow_model.h5") - >>> model = tf.keras.Model() - >>> predictions = model.predict([...]) - >>> - >>> data_set.save(model) - >>> loaded_model = data_set.load() - >>> new_predictions = loaded_model.predict([...]) - >>> np.testing.assert_allclose(predictions, new_predictions, rtol=1e-6, atol=1e-6) - - """ - - DEFAULT_LOAD_ARGS = {} # type: Dict[str, Any] - DEFAULT_SAVE_ARGS = {"save_format": "tf"} # type: Dict[str, Any] - - def __init__( # noqa: too-many-arguments - self, - filepath: str, - load_args: Dict[str, Any] = None, - save_args: Dict[str, Any] = None, - version: Version = None, - credentials: Dict[str, Any] = None, - fs_args: Dict[str, Any] = None, - ) -> None: - """Creates a new instance of ``TensorFlowModelDataset``. - - Args: - filepath: Filepath in POSIX format to a TensorFlow model directory prefixed with a - protocol like `s3://`. If prefix is not provided `file` protocol (local filesystem) - will be used. The prefix should be any protocol supported by ``fsspec``. - Note: `http(s)` doesn't support versioning. - load_args: TensorFlow options for loading models. - Here you can find all available arguments: - https://www.tensorflow.org/api_docs/python/tf/keras/models/load_model - All defaults are preserved. - save_args: TensorFlow options for saving models. - Here you can find all available arguments: - https://www.tensorflow.org/api_docs/python/tf/keras/models/save_model - All defaults are preserved, except for "save_format", which is set to "tf". - version: If specified, should be an instance of - ``kedro.io.core.Version``. If its ``load`` attribute is - None, the latest version will be loaded. If its ``save`` - attribute is None, save version will be autogenerated. - credentials: Credentials required to get access to the underlying filesystem. - E.g. for ``GCSFileSystem`` it should look like `{'token': None}`. - fs_args: Extra arguments to pass into underlying filesystem class constructor - (e.g. `{"project": "my-project"}` for ``GCSFileSystem``). 
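The `load_args` documented below include a dataset-specific `tf_device` key (shown as optional in the README earlier in this diff); `_load` pops it and wraps `tf.keras.models.load_model` in a `tf.device` context. A hedged sketch of that behaviour with a placeholder model directory:

```python
# What load_args={"tf_device": "/CPU:0"} amounts to inside TensorFlowModelDataset._load.
# The model directory path is a placeholder.
import tensorflow as tf

device_name = "/CPU:0"
with tf.device(device_name):
    model = tf.keras.models.load_model(
        "data/06_models/tf_model_dirname", compile=False
    )
```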
- """ - _fs_args = copy.deepcopy(fs_args) or {} - _credentials = copy.deepcopy(credentials) or {} - protocol, path = get_protocol_and_path(filepath, version) - if protocol == "file": - _fs_args.setdefault("auto_mkdir", True) - - self._protocol = protocol - self._fs = fsspec.filesystem(self._protocol, **_credentials, **_fs_args) - super().__init__( - filepath=PurePosixPath(path), - version=version, - exists_function=self._fs.exists, - glob_function=self._fs.glob, - ) - - self._tmp_prefix = "kedro_tensorflow_tmp" # temp prefix pattern - - # Handle default load and save arguments - self._load_args = copy.deepcopy(self.DEFAULT_LOAD_ARGS) - if load_args is not None: - self._load_args.update(load_args) - self._save_args = copy.deepcopy(self.DEFAULT_SAVE_ARGS) - if save_args is not None: - self._save_args.update(save_args) - - self._is_h5 = self._save_args.get("save_format") == "h5" - - def _load(self) -> tf.keras.Model: - load_path = get_filepath_str(self._get_load_path(), self._protocol) - - with tempfile.TemporaryDirectory(prefix=self._tmp_prefix) as path: - if self._is_h5: - path = str( # noqa: PLW2901 - PurePath(path) / TEMPORARY_H5_FILE - ) # noqa: redefined-loop-name - self._fs.copy(load_path, path) - else: - self._fs.get(load_path, path, recursive=True) - - # Pass the local temporary directory/file path to keras.load_model - device_name = self._load_args.pop("tf_device", None) - if device_name: - with tf.device(device_name): - model = tf.keras.models.load_model(path, **self._load_args) - else: - model = tf.keras.models.load_model(path, **self._load_args) - return model - - def _save(self, data: tf.keras.Model) -> None: - save_path = get_filepath_str(self._get_save_path(), self._protocol) - - with tempfile.TemporaryDirectory(prefix=self._tmp_prefix) as path: - if self._is_h5: - path = str( # noqa: PLW2901 - PurePath(path) / TEMPORARY_H5_FILE - ) # noqa: redefined-loop-name - - tf.keras.models.save_model(data, path, **self._save_args) - - # Use fsspec to take from local tempfile directory/file and - # put in ArbitraryFileSystem - if self._is_h5: - self._fs.copy(path, save_path) - else: - if self._fs.exists(save_path): - self._fs.rm(save_path, recursive=True) - self._fs.put(path, save_path, recursive=True) - - def _exists(self) -> bool: - try: - load_path = get_filepath_str(self._get_load_path(), self._protocol) - except DatasetError: - return False - return self._fs.exists(load_path) - - def _describe(self) -> Dict[str, Any]: - return { - "filepath": self._filepath, - "protocol": self._load_args, - "save_args": self._save_args, - "version": self._version, - } - - def _release(self) -> None: - super()._release() - self._invalidate_cache() - - def _invalidate_cache(self) -> None: - """Invalidate underlying filesystem caches.""" - filepath = get_filepath_str(self._filepath, self._protocol) - self._fs.invalidate_cache(filepath) diff --git a/kedro/extras/datasets/text/__init__.py b/kedro/extras/datasets/text/__init__.py deleted file mode 100644 index 9ed2c37c0e..0000000000 --- a/kedro/extras/datasets/text/__init__.py +++ /dev/null @@ -1,8 +0,0 @@ -"""``AbstractDataset`` implementation to load/save data from/to a text file.""" - -__all__ = ["TextDataSet"] - -from contextlib import suppress - -with suppress(ImportError): - from .text_dataset import TextDataSet diff --git a/kedro/extras/datasets/text/text_dataset.py b/kedro/extras/datasets/text/text_dataset.py deleted file mode 100644 index 253ee92826..0000000000 --- a/kedro/extras/datasets/text/text_dataset.py +++ /dev/null @@ -1,144 +0,0 @@ 
-"""``TextDataSet`` loads/saves data from/to a text file using an underlying -filesystem (e.g.: local, S3, GCS). -""" -from copy import deepcopy -from pathlib import PurePosixPath -from typing import Any, Dict - -import fsspec - -from kedro.io.core import ( - AbstractVersionedDataset, - DatasetError, - Version, - get_filepath_str, - get_protocol_and_path, -) - -# NOTE: kedro.extras.datasets will be removed in Kedro 0.19.0. -# Any contribution to datasets should be made in kedro-datasets -# in kedro-plugins (https://github.com/kedro-org/kedro-plugins) - - -class TextDataSet(AbstractVersionedDataset[str, str]): - """``TextDataSet`` loads/saves data from/to a text file using an underlying - filesystem (e.g.: local, S3, GCS) - - Example usage for the - `YAML API `_: - - - .. code-block:: yaml - - alice_book: - type: text.TextDataSet - filepath: data/01_raw/alice.txt - - Example usage for the - `Python API `_: - :: - - >>> from kedro.extras.datasets.text import TextDataSet - >>> - >>> string_to_write = "This will go in a file." - >>> - >>> data_set = TextDataSet(filepath="test.md") - >>> data_set.save(string_to_write) - >>> reloaded = data_set.load() - >>> assert string_to_write == reloaded - - """ - - def __init__( - self, - filepath: str, - version: Version = None, - credentials: Dict[str, Any] = None, - fs_args: Dict[str, Any] = None, - ) -> None: - """Creates a new instance of ``TextDataSet`` pointing to a concrete text file - on a specific filesystem. - - Args: - filepath: Filepath in POSIX format to a text file prefixed with a protocol like `s3://`. - If prefix is not provided, `file` protocol (local filesystem) will be used. - The prefix should be any protocol supported by ``fsspec``. - Note: `http(s)` doesn't support versioning. - version: If specified, should be an instance of - ``kedro.io.core.Version``. If its ``load`` attribute is - None, the latest version will be loaded. If its ``save`` - attribute is None, save version will be autogenerated. - credentials: Credentials required to get access to the underlying filesystem. - E.g. for ``GCSFileSystem`` it should look like `{"token": None}`. - fs_args: Extra arguments to pass into underlying filesystem class constructor - (e.g. `{"project": "my-project"}` for ``GCSFileSystem``), as well as - to pass to the filesystem's `open` method through nested keys - `open_args_load` and `open_args_save`. - Here you can find all available arguments for `open`: - https://filesystem-spec.readthedocs.io/en/latest/api.html#fsspec.spec.AbstractFileSystem.open - All defaults are preserved, except `mode`, which is set to `r` when loading - and to `w` when saving. 
- """ - _fs_args = deepcopy(fs_args) or {} - _fs_open_args_load = _fs_args.pop("open_args_load", {}) - _fs_open_args_save = _fs_args.pop("open_args_save", {}) - _credentials = deepcopy(credentials) or {} - - protocol, path = get_protocol_and_path(filepath, version) - if protocol == "file": - _fs_args.setdefault("auto_mkdir", True) - - self._protocol = protocol - self._fs = fsspec.filesystem(self._protocol, **_credentials, **_fs_args) - - super().__init__( - filepath=PurePosixPath(path), - version=version, - exists_function=self._fs.exists, - glob_function=self._fs.glob, - ) - - _fs_open_args_load.setdefault("mode", "r") - _fs_open_args_save.setdefault("mode", "w") - self._fs_open_args_load = _fs_open_args_load - self._fs_open_args_save = _fs_open_args_save - - def _describe(self) -> Dict[str, Any]: - return { - "filepath": self._filepath, - "protocol": self._protocol, - "version": self._version, - } - - def _load(self) -> str: - load_path = get_filepath_str(self._get_load_path(), self._protocol) - - with self._fs.open(load_path, **self._fs_open_args_load) as fs_file: - return fs_file.read() - - def _save(self, data: str) -> None: - save_path = get_filepath_str(self._get_save_path(), self._protocol) - - with self._fs.open(save_path, **self._fs_open_args_save) as fs_file: - fs_file.write(data) - - self._invalidate_cache() - - def _exists(self) -> bool: - try: - load_path = get_filepath_str(self._get_load_path(), self._protocol) - except DatasetError: - return False - - return self._fs.exists(load_path) - - def _release(self) -> None: - super()._release() - self._invalidate_cache() - - def _invalidate_cache(self) -> None: - """Invalidate underlying filesystem caches.""" - filepath = get_filepath_str(self._filepath, self._protocol) - self._fs.invalidate_cache(filepath) diff --git a/kedro/extras/datasets/tracking/__init__.py b/kedro/extras/datasets/tracking/__init__.py deleted file mode 100644 index 2b4d185ba8..0000000000 --- a/kedro/extras/datasets/tracking/__init__.py +++ /dev/null @@ -1,11 +0,0 @@ -"""Dataset implementations to save data for Kedro Experiment Tracking""" - -__all__ = ["MetricsDataSet", "JSONDataSet"] - - -from contextlib import suppress - -with suppress(ImportError): - from kedro.extras.datasets.tracking.metrics_dataset import MetricsDataSet -with suppress(ImportError): - from kedro.extras.datasets.tracking.json_dataset import JSONDataSet diff --git a/kedro/extras/datasets/tracking/json_dataset.py b/kedro/extras/datasets/tracking/json_dataset.py deleted file mode 100644 index a41491492b..0000000000 --- a/kedro/extras/datasets/tracking/json_dataset.py +++ /dev/null @@ -1,49 +0,0 @@ -"""``JSONDataSet`` saves data to a JSON file using an underlying -filesystem (e.g.: local, S3, GCS). It uses native json to handle the JSON file. -The ``JSONDataSet`` is part of Kedro Experiment Tracking. The dataset is versioned by default. -""" -from typing import NoReturn - -from kedro.extras.datasets.json import JSONDataSet as JDS -from kedro.io.core import DatasetError - -# NOTE: kedro.extras.datasets will be removed in Kedro 0.19.0. -# Any contribution to datasets should be made in kedro-datasets -# in kedro-plugins (https://github.com/kedro-org/kedro-plugins) - - -class JSONDataSet(JDS): - """``JSONDataSet`` saves data to a JSON file using an underlying - filesystem (e.g.: local, S3, GCS). It uses native json to handle the JSON file. - The ``JSONDataSet`` is part of Kedro Experiment Tracking. - The dataset is write-only and it is versioned by default. 
- - Example usage for the - `YAML API `_: - - - .. code-block:: yaml - - cars: - type: tracking.JSONDataSet - filepath: data/09_tracking/cars.json - - Example usage for the - `Python API `_: - :: - - >>> from kedro.extras.datasets.tracking import JSONDataSet - >>> - >>> data = {'col1': 1, 'col2': 0.23, 'col3': 0.002} - >>> - >>> data_set = JSONDataSet(filepath="test.json") - >>> data_set.save(data) - - """ - - versioned = True - - def _load(self) -> NoReturn: - raise DatasetError(f"Loading not supported for '{self.__class__.__name__}'") diff --git a/kedro/extras/datasets/tracking/metrics_dataset.py b/kedro/extras/datasets/tracking/metrics_dataset.py deleted file mode 100644 index b2a1949702..0000000000 --- a/kedro/extras/datasets/tracking/metrics_dataset.py +++ /dev/null @@ -1,70 +0,0 @@ -"""``MetricsDataSet`` saves data to a JSON file using an underlying -filesystem (e.g.: local, S3, GCS). It uses native json to handle the JSON file. -The ``MetricsDataSet`` is part of Kedro Experiment Tracking. The dataset is versioned by default -and only takes metrics of numeric values. -""" -import json -from typing import Dict, NoReturn - -from kedro.extras.datasets.json import JSONDataSet -from kedro.io.core import DatasetError, get_filepath_str - -# NOTE: kedro.extras.datasets will be removed in Kedro 0.19.0. -# Any contribution to datasets should be made in kedro-datasets -# in kedro-plugins (https://github.com/kedro-org/kedro-plugins) - - -class MetricsDataSet(JSONDataSet): - """``MetricsDataSet`` saves data to a JSON file using an underlying - filesystem (e.g.: local, S3, GCS). It uses native json to handle the JSON file. The - ``MetricsDataSet`` is part of Kedro Experiment Tracking. The dataset is write-only, - it is versioned by default and only takes metrics of numeric values. - - Example usage for the - `YAML API `_: - - - .. code-block:: yaml - - cars: - type: metrics.MetricsDataSet - filepath: data/09_tracking/cars.json - - Example usage for the - `Python API `_: - :: - - >>> from kedro.extras.datasets.tracking import MetricsDataSet - >>> - >>> data = {'col1': 1, 'col2': 0.23, 'col3': 0.002} - >>> - >>> data_set = MetricsDataSet(filepath="test.json") - >>> data_set.save(data) - - """ - - versioned = True - - def _load(self) -> NoReturn: - raise DatasetError(f"Loading not supported for '{self.__class__.__name__}'") - - def _save(self, data: Dict[str, float]) -> None: - """Converts all values in the data from a ``MetricsDataSet`` to float to make sure - they are numeric values which can be displayed in Kedro Viz and then saves the dataset. - """ - try: - for key, value in data.items(): - data[key] = float(value) - except ValueError as exc: - raise DatasetError( - f"The MetricsDataSet expects only numeric values. 
{exc}" - ) from exc - - save_path = get_filepath_str(self._get_save_path(), self._protocol) - - with self._fs.open(save_path, **self._fs_open_args_save) as fs_file: - json.dump(data, fs_file, **self._save_args) - - self._invalidate_cache() diff --git a/kedro/extras/datasets/video/__init__.py b/kedro/extras/datasets/video/__init__.py deleted file mode 100644 index f5f7af9461..0000000000 --- a/kedro/extras/datasets/video/__init__.py +++ /dev/null @@ -1,5 +0,0 @@ -"""Dataset implementation to load/save data from/to a video file.""" - -__all__ = ["VideoDataSet"] - -from kedro.extras.datasets.video.video_dataset import VideoDataSet diff --git a/kedro/extras/datasets/video/video_dataset.py b/kedro/extras/datasets/video/video_dataset.py deleted file mode 100644 index 08e93126ec..0000000000 --- a/kedro/extras/datasets/video/video_dataset.py +++ /dev/null @@ -1,357 +0,0 @@ -"""``VideoDataSet`` loads/saves video data from an underlying -filesystem (e.g.: local, S3, GCS). It uses OpenCV VideoCapture to read -and decode videos and OpenCV VideoWriter to encode and write video. -""" -import itertools -import tempfile -from collections import abc -from copy import deepcopy -from pathlib import Path, PurePosixPath -from typing import Any, Dict, Generator, Optional, Sequence, Tuple, Union - -import cv2 -import fsspec -import numpy as np -import PIL.Image - -from kedro.io.core import AbstractDataset, get_protocol_and_path - - -class SlicedVideo: - """A representation of slices of other video types""" - - def __init__(self, video, slice_indexes): - self.video = video - self.indexes = range(*slice_indexes.indices(len(video))) - - def __getitem__(self, index: Union[int, slice]) -> PIL.Image.Image: - if isinstance(index, slice): - return SlicedVideo(self, index) - return self.video[self.indexes[index]] - - def __len__(self) -> int: - return len(self.indexes) - - def __getattr__(self, item): - return getattr(self.video, item) - - -class AbstractVideo(abc.Sequence): - """Base class for the underlying video data""" - - _n_frames = 0 - _index = 0 # Next available frame - - @property - def fourcc(self) -> str: - """Get the codec fourcc specification""" - raise NotImplementedError() - - @property - def fps(self) -> float: - """Get the video frame rate""" - raise NotImplementedError() - - @property - def size(self) -> Tuple[int, int]: - """Get the resolution of the video""" - raise NotImplementedError() - - def __len__(self) -> int: - return self._n_frames - - def __getitem__(self, index: Union[int, slice]): - """Get a frame from the video""" - raise NotImplementedError() - - -class FileVideo(AbstractVideo): - """A video object read from a file""" - - def __init__(self, filepath: str) -> None: - self._filepath = filepath - self._cap = cv2.VideoCapture(filepath) - self._n_frames = self._get_length() - - @property - def fourcc(self) -> str: - fourcc = self._cap.get(cv2.CAP_PROP_FOURCC) - return int(fourcc).to_bytes(4, "little").decode("ascii") - - @property - def fps(self) -> float: - return self._cap.get(cv2.CAP_PROP_FPS) - - @property - def size(self) -> Tuple[int, int]: - width = int(self._cap.get(cv2.CAP_PROP_FRAME_WIDTH)) - height = int(self._cap.get(cv2.CAP_PROP_FRAME_HEIGHT)) - return width, height - - def __getitem__(self, index: Union[int, slice]): - if isinstance(index, slice): - return SlicedVideo(self, index) - - if index < 0: - index += len(self) - if index >= len(self): - raise IndexError() - - if index != self._index: - self._cap.set(cv2.CAP_PROP_POS_FRAMES, index) - self._index = index + 1 # Next frame 
to decode after this - ret, frame_bgr = self._cap.read() - if not ret: - raise IndexError() - - height, width = frame_bgr.shape[:2] - return PIL.Image.frombuffer( # Convert to PIL image with RGB instead of BGR - "RGB", (width, height), frame_bgr, "raw", "BGR", 0, 0 - ) - - def _get_length(self) -> int: - # OpenCV's frame count might be an approximation depending on what - # headers are available in the video file - length = int(round(self._cap.get(cv2.CAP_PROP_FRAME_COUNT))) - if length >= 0: - return length - - # Getting the frame count with OpenCV can fail on some video files, - # counting the frames would be too slow so it is better to raise an exception. - raise ValueError( - "Failed to load video since number of frames can't be inferred" - ) - - -class SequenceVideo(AbstractVideo): - """A video object read from an indexable sequence of frames""" - - def __init__( - self, frames: Sequence[PIL.Image.Image], fps: float, fourcc: str = "mp4v" - ) -> None: - self._n_frames = len(frames) - self._frames = frames - self._fourcc = fourcc - self._size = frames[0].size - self._fps = fps - - @property - def fourcc(self) -> str: - return self._fourcc - - @property - def fps(self) -> float: - return self._fps - - @property - def size(self) -> Tuple[int, int]: - return self._size - - def __getitem__(self, index: Union[int, slice]): - if isinstance(index, slice): - return SlicedVideo(self, index) - return self._frames[index] - - -class GeneratorVideo(AbstractVideo): - """A video object with frames yielded by a generator""" - - def __init__( - self, - frames: Generator[PIL.Image.Image, None, None], - length, - fps: float, - fourcc: str = "mp4v", - ) -> None: - self._n_frames = length - first = next(frames) - self._gen = itertools.chain([first], frames) - self._fourcc = fourcc - self._size = first.size - self._fps = fps - - @property - def fourcc(self) -> str: - return self._fourcc - - @property - def fps(self) -> float: - return self._fps - - @property - def size(self) -> Tuple[int, int]: - return self._size - - def __getitem__(self, index: Union[int, slice]): - raise NotImplementedError("Underlying video is a generator") - - def __next__(self): - return next(self._gen) - - def __iter__(self): - return self - - -class VideoDataSet(AbstractDataset[AbstractVideo, AbstractVideo]): - """``VideoDataSet`` loads / save video data from a given filepath as sequence - of PIL.Image.Image using OpenCV. - - Example usage for the - `YAML API `_: - - - .. 
code-block:: yaml - - cars: - type: video.VideoDataSet - filepath: data/01_raw/cars.mp4 - - motorbikes: - type: video.VideoDataSet - filepath: s3://your_bucket/data/02_intermediate/company/motorbikes.mp4 - credentials: dev_s3 - - Example usage for the - `Python API `_: - :: - - >>> from kedro.extras.datasets.video import VideoDataSet - >>> import numpy as np - >>> - >>> video = VideoDataSet(filepath='/video/file/path.mp4').load() - >>> frame = video[0] - >>> np.sum(np.asarray(frame)) - - - Example creating a video from numpy frames using Python API: - :: - - >>> from kedro.extras.datasets.video.video_dataset import VideoDataSet, SequenceVideo - >>> import numpy as np - >>> from PIL import Image - >>> - >>> frame = np.ones((640,480,3), dtype=np.uint8) * 255 - >>> imgs = [] - >>> for i in range(255): - >>> imgs.append(Image.fromarray(frame)) - >>> frame -= 1 - >>> - >>> video = VideoDataSet("my_video.mp4") - >>> video.save(SequenceVideo(imgs, fps=25)) - - - Example creating a video from numpy frames using a generator and the Python API: - :: - - >>> from kedro.extras.datasets.video.video_dataset import VideoDataSet, GeneratorVideo - >>> import numpy as np - >>> from PIL import Image - >>> - >>> def gen(): - >>> frame = np.ones((640,480,3), dtype=np.uint8) * 255 - >>> for i in range(255): - >>> yield Image.fromarray(frame) - >>> frame -= 1 - >>> - >>> video = VideoDataSet("my_video.mp4") - >>> video.save(GeneratorVideo(gen(), fps=25, length=None)) - - """ - - def __init__( - self, - filepath: str, - fourcc: Optional[str] = "mp4v", - credentials: Dict[str, Any] = None, - fs_args: Dict[str, Any] = None, - ) -> None: - """Creates a new instance of VideoDataSet to load / save video data for given filepath. - - Args: - filepath: The location of the video file to load / save data. - fourcc: The codec to use when writing video, note that depending on how opencv is - installed there might be more or less codecs avaiable. If set to None, the - fourcc from the video object will be used. - credentials: Credentials required to get access to the underlying filesystem. - E.g. for ``GCSFileSystem`` it should look like `{"token": None}`. - fs_args: Extra arguments to pass into underlying filesystem class constructor - (e.g. `{"project": "my-project"}` for ``GCSFileSystem``). - """ - # parse the path and protocol (e.g. file, http, s3, etc.) - protocol, path = get_protocol_and_path(filepath) - self._protocol = protocol - self._filepath = PurePosixPath(path) - self._fourcc = fourcc - _fs_args = deepcopy(fs_args) or {} - _credentials = deepcopy(credentials) or {} - self._storage_options = {**_credentials, **_fs_args} - self._fs = fsspec.filesystem(self._protocol, **self._storage_options) - - def _load(self) -> AbstractVideo: - """Loads data from the video file. - - Returns: - Data from the video file as a AbstractVideo object - """ - with fsspec.open( - f"filecache::{self._protocol}://{self._filepath}", - mode="rb", - **{self._protocol: self._storage_options}, - ) as fs_file: - return FileVideo(fs_file.name) - - def _save(self, data: AbstractVideo) -> None: - """Saves video data to the specified filepath.""" - if self._protocol == "file": - # Write directly to the local file destination - self._write_to_filepath(data, str(self._filepath)) - else: - # VideoWriter can't write to an open file object, instead write to a - # local tmpfile and then copy that to the destination with fsspec. 
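`VideoDataSet` exposes frames as RGB `PIL.Image` objects while OpenCV reads and writes BGR, so the dataset flips the channel order in both directions (via `PIL.Image.frombuffer` on read and a reversed slice on write). A simplified NumPy sketch of that channel handling:

```python
# Channel-order handling used by VideoDataSet: OpenCV frames are BGR,
# the exposed PIL images are RGB, and the channels are reversed again on write.
import numpy as np
import PIL.Image

bgr_frame = np.zeros((480, 640, 3), dtype=np.uint8)
bgr_frame[:, :, 0] = 255  # pure blue in OpenCV's BGR layout

rgb_image = PIL.Image.fromarray(bgr_frame[:, :, ::-1])  # read path: BGR -> RGB
back_to_bgr = np.asarray(rgb_image)[:, :, ::-1]         # write path: RGB -> BGR

assert (back_to_bgr == bgr_frame).all()
```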
- # Note that the VideoWriter fails to write to the file on Windows if - # the file is already open, thus we can't use NamedTemporaryFile. - with tempfile.TemporaryDirectory() as tmp_dir: - tmp_file = Path(tmp_dir) / self._filepath.name - self._write_to_filepath(data, str(tmp_file)) - with fsspec.open( - f"{self._protocol}://{self._filepath}", - "wb", - **self._storage_options, - ) as f_target: - with tmp_file.open("r+b") as f_tmp: - f_target.write(f_tmp.read()) - - def _write_to_filepath(self, video: AbstractVideo, filepath: str) -> None: - # TODO: This uses the codec specified in the VideoDataSet if it is not None, this is due - # to compatibility issues since e.g. h264 coded is licensed and is thus not included in - # opencv if installed from a binary distribution. Since a h264 video can be read, but not - # written, it would be error prone to use the videos fourcc code. Further, an issue is - # that the video object does not know what container format will be used since that is - # selected by the suffix in the file name of the VideoDataSet. Some combinations of codec - # and container format might not work or will have bad support. - fourcc = self._fourcc or video.fourcc - - writer = cv2.VideoWriter( - filepath, cv2.VideoWriter_fourcc(*fourcc), video.fps, video.size - ) - if not writer.isOpened(): - raise ValueError( - "Failed to open video writer with params: " - + f"fourcc={fourcc} fps={video.fps} size={video.size[0]}x{video.size[1]} " - + f"path={filepath}" - ) - try: - for frame in iter(video): - writer.write( # PIL images are RGB, opencv expects BGR - np.asarray(frame)[:, :, ::-1] - ) - finally: - writer.release() - - def _describe(self) -> Dict[str, Any]: - return {"filepath": self._filepath, "protocol": self._protocol} - - def _exists(self) -> bool: - return self._fs.exists(self._filepath) diff --git a/kedro/extras/datasets/yaml/__init__.py b/kedro/extras/datasets/yaml/__init__.py deleted file mode 100644 index 07abbaf4a5..0000000000 --- a/kedro/extras/datasets/yaml/__init__.py +++ /dev/null @@ -1,8 +0,0 @@ -"""``AbstractDataset`` implementation to load/save data from/to a YAML file.""" - -__all__ = ["YAMLDataSet"] - -from contextlib import suppress - -with suppress(ImportError): - from .yaml_dataset import YAMLDataSet diff --git a/kedro/extras/datasets/yaml/yaml_dataset.py b/kedro/extras/datasets/yaml/yaml_dataset.py deleted file mode 100644 index a98e76314e..0000000000 --- a/kedro/extras/datasets/yaml/yaml_dataset.py +++ /dev/null @@ -1,156 +0,0 @@ -"""``YAMLDataSet`` loads/saves data from/to a YAML file using an underlying -filesystem (e.g.: local, S3, GCS). It uses PyYAML to handle the YAML file. -""" -from copy import deepcopy -from pathlib import PurePosixPath -from typing import Any, Dict - -import fsspec -import yaml - -from kedro.io.core import ( - AbstractVersionedDataset, - DatasetError, - Version, - get_filepath_str, - get_protocol_and_path, -) - -# NOTE: kedro.extras.datasets will be removed in Kedro 0.19.0. -# Any contribution to datasets should be made in kedro-datasets -# in kedro-plugins (https://github.com/kedro-org/kedro-plugins) - - -class YAMLDataSet(AbstractVersionedDataset[Dict, Dict]): - """``YAMLDataSet`` loads/saves data from/to a YAML file using an underlying - filesystem (e.g.: local, S3, GCS). It uses PyYAML to handle the YAML file. - - Example usage for the - `YAML API `_: - - - .. 
code-block:: yaml - - cars: - type: yaml.YAMLDataSet - filepath: cars.yaml - - Example usage for the - `Python API `_: - :: - - >>> from kedro.extras.datasets.yaml import YAMLDataSet - >>> - >>> data = {'col1': [1, 2], 'col2': [4, 5], 'col3': [5, 6]} - >>> - >>> data_set = YAMLDataSet(filepath="test.yaml") - >>> data_set.save(data) - >>> reloaded = data_set.load() - >>> assert data == reloaded - - """ - - DEFAULT_SAVE_ARGS = {"default_flow_style": False} # type: Dict[str, Any] - - def __init__( # noqa: too-many-arguments - self, - filepath: str, - save_args: Dict[str, Any] = None, - version: Version = None, - credentials: Dict[str, Any] = None, - fs_args: Dict[str, Any] = None, - ) -> None: - """Creates a new instance of ``YAMLDataSet`` pointing to a concrete YAML file - on a specific filesystem. - - Args: - filepath: Filepath in POSIX format to a YAML file prefixed with a protocol like `s3://`. - If prefix is not provided, `file` protocol (local filesystem) will be used. - The prefix should be any protocol supported by ``fsspec``. - Note: `http(s)` doesn't support versioning. - save_args: PyYAML options for saving YAML files (arguments passed - into ```yaml.dump``). Here you can find all available arguments: - https://pyyaml.org/wiki/PyYAMLDocumentation - All defaults are preserved, but "default_flow_style", which is set to False. - version: If specified, should be an instance of - ``kedro.io.core.Version``. If its ``load`` attribute is - None, the latest version will be loaded. If its ``save`` - attribute is None, save version will be autogenerated. - credentials: Credentials required to get access to the underlying filesystem. - E.g. for ``GCSFileSystem`` it should look like `{"token": None}`. - fs_args: Extra arguments to pass into underlying filesystem class constructor - (e.g. `{"project": "my-project"}` for ``GCSFileSystem``), as well as - to pass to the filesystem's `open` method through nested keys - `open_args_load` and `open_args_save`. - Here you can find all available arguments for `open`: - https://filesystem-spec.readthedocs.io/en/latest/api.html#fsspec.spec.AbstractFileSystem.open - All defaults are preserved, except `mode`, which is set to `r` when loading - and to `w` when saving. 
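`YAMLDataSet` overrides one PyYAML default, setting `default_flow_style` to `False` so output is written in block style. A small sketch of the difference that setting makes:

```python
# Effect of the default_flow_style save arg that YAMLDataSet sets to False.
import yaml

data = {"col1": [1, 2], "col2": [4, 5]}

print(yaml.dump(data, default_flow_style=False))
# col1:
# - 1
# - 2
# col2:
# - 4
# - 5

print(yaml.dump(data, default_flow_style=True))
# {col1: [1, 2], col2: [4, 5]}
```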
- """ - _fs_args = deepcopy(fs_args) or {} - _fs_open_args_load = _fs_args.pop("open_args_load", {}) - _fs_open_args_save = _fs_args.pop("open_args_save", {}) - _credentials = deepcopy(credentials) or {} - - protocol, path = get_protocol_and_path(filepath, version) - if protocol == "file": - _fs_args.setdefault("auto_mkdir", True) - - self._protocol = protocol - self._fs = fsspec.filesystem(self._protocol, **_credentials, **_fs_args) - - super().__init__( - filepath=PurePosixPath(path), - version=version, - exists_function=self._fs.exists, - glob_function=self._fs.glob, - ) - - # Handle default save arguments - self._save_args = deepcopy(self.DEFAULT_SAVE_ARGS) - if save_args is not None: - self._save_args.update(save_args) - - _fs_open_args_save.setdefault("mode", "w") - self._fs_open_args_load = _fs_open_args_load - self._fs_open_args_save = _fs_open_args_save - - def _describe(self) -> Dict[str, Any]: - return { - "filepath": self._filepath, - "protocol": self._protocol, - "save_args": self._save_args, - "version": self._version, - } - - def _load(self) -> Dict: - load_path = get_filepath_str(self._get_load_path(), self._protocol) - - with self._fs.open(load_path, **self._fs_open_args_load) as fs_file: - return yaml.safe_load(fs_file) - - def _save(self, data: Dict) -> None: - save_path = get_filepath_str(self._get_save_path(), self._protocol) - with self._fs.open(save_path, **self._fs_open_args_save) as fs_file: - yaml.dump(data, fs_file, **self._save_args) - - self._invalidate_cache() - - def _exists(self) -> bool: - try: - load_path = get_filepath_str(self._get_load_path(), self._protocol) - except DatasetError: - return False - - return self._fs.exists(load_path) - - def _release(self) -> None: - super()._release() - self._invalidate_cache() - - def _invalidate_cache(self) -> None: - """Invalidate underlying filesystem caches.""" - filepath = get_filepath_str(self._filepath, self._protocol) - self._fs.invalidate_cache(filepath) diff --git a/kedro/extras/extensions/__init__.py b/kedro/extras/extensions/__init__.py deleted file mode 100644 index 8cad7f44a1..0000000000 --- a/kedro/extras/extensions/__init__.py +++ /dev/null @@ -1,3 +0,0 @@ -""" -This module contains an IPython extension. -""" diff --git a/kedro/extras/extensions/ipython.py b/kedro/extras/extensions/ipython.py deleted file mode 100644 index ee700b571e..0000000000 --- a/kedro/extras/extensions/ipython.py +++ /dev/null @@ -1,22 +0,0 @@ -""" -This file and directory exists purely for backwards compatibility of the following: -%load_ext kedro.extras.extensions.ipython -from kedro.extras.extensions.ipython import reload_kedro - -Any modifications to the IPython extension should now be made in kedro/ipython/. -The Kedro IPython extension should always be loaded as %load_ext kedro.ipython. -Line magics such as reload_kedro should always be called as line magics rather than -importing the underlying Python functions. -""" -import warnings - -from ...ipython import ( # noqa # noqa: unused-import - load_ipython_extension, - reload_kedro, -) - -warnings.warn( - "kedro.extras.extensions.ipython should be accessed only using the alias " - "kedro.ipython. The unaliased name will be removed in Kedro 0.19.0.", - DeprecationWarning, -) diff --git a/kedro/extras/logging/__init__.py b/kedro/extras/logging/__init__.py deleted file mode 100644 index 2892db7846..0000000000 --- a/kedro/extras/logging/__init__.py +++ /dev/null @@ -1,15 +0,0 @@ -""" -This module contains a logging handler class which produces coloured logs. 
-""" -import warnings - -from .color_logger import ColorHandler - -__all__ = ["ColorHandler"] - -warnings.simplefilter("default", DeprecationWarning) - -warnings.warn( - "Support for ColorHandler will be removed in Kedro 0.19.0.", - DeprecationWarning, -) diff --git a/kedro/extras/logging/color_logger.py b/kedro/extras/logging/color_logger.py deleted file mode 100644 index e468b2b5e8..0000000000 --- a/kedro/extras/logging/color_logger.py +++ /dev/null @@ -1,95 +0,0 @@ -"""A logging handler class which produces coloured logs.""" - - -import logging - -import click - - -class ColorHandler(logging.StreamHandler): - """A color log handler. - - You can use this handler by incorporating the example below into your - logging configuration: - - ``conf/project/logging.yml``: - :: - - formatters: - simple: - format: "%(asctime)s - %(name)s - %(levelname)s - %(message)s" - - handlers: - console: - class: kedro.extras.logging.ColorHandler - level: INFO - formatter: simple - stream: ext://sys.stdout - # defining colors is optional - colors: - debug: white - info: magenta - warning: yellow - - root: - level: INFO - handlers: [console] - - The ``colors`` parameter is optional, and you can use any ANSI color. - - * Black - * Red - * Green - * Yellow - * Blue - * Magenta - * Cyan - * White - - The default colors are: - - * debug: magenta - * info: cyan - * warning: yellow - * error: red - * critical: red - """ - - def __init__(self, stream=None, colors=None): - logging.StreamHandler.__init__(self, stream) - colors = colors or {} - self.colors = { - "critical": colors.get("critical", "red"), - "error": colors.get("error", "red"), - "warning": colors.get("warning", "yellow"), - "info": colors.get("info", "cyan"), - "debug": colors.get("debug", "magenta"), - } - - def _get_color(self, level): - if level >= logging.CRITICAL: - return self.colors["critical"] # pragma: no cover - if level >= logging.ERROR: - return self.colors["error"] # pragma: no cover - if level >= logging.WARNING: - return self.colors["warning"] # pragma: no cover - if level >= logging.INFO: - return self.colors["info"] - if level >= logging.DEBUG: # pragma: no cover - return self.colors["debug"] # pragma: no cover - - return None # pragma: no cover - - def format(self, record: logging.LogRecord) -> str: - """The handler formatter. - - Args: - record: The record to format. - - Returns: - The record formatted as a string. 
- - """ - text = logging.StreamHandler.format(self, record) - color = self._get_color(record.levelno) - return click.style(text, color) diff --git a/kedro/framework/cli/catalog.py b/kedro/framework/cli/catalog.py index 24816a9492..36e0ed17c3 100644 --- a/kedro/framework/cli/catalog.py +++ b/kedro/framework/cli/catalog.py @@ -15,12 +15,7 @@ def _create_session(package_name: str, **kwargs): kwargs.setdefault("save_on_close", False) - try: - return KedroSession.create(package_name, **kwargs) - except Exception as exc: - raise KedroCliError( - f"Unable to instantiate Kedro session.\nError: {exc}" - ) from exc + return KedroSession.create(package_name, **kwargs) # noqa: missing-function-docstring @@ -57,9 +52,14 @@ def list_datasets(metadata: ProjectMetadata, pipeline, env): session = _create_session(metadata.package_name, env=env) context = session.load_context() - data_catalog = context.catalog - datasets_meta = data_catalog._data_sets - catalog_ds = set(data_catalog.list()) + try: + data_catalog = context.catalog + datasets_meta = data_catalog._data_sets + catalog_ds = set(data_catalog.list()) + except Exception as exc: + raise KedroCliError( + f"Unable to instantiate Kedro Catalog.\nError: {exc}" + ) from exc target_pipelines = pipeline or pipelines.keys() diff --git a/kedro/framework/cli/cli.py b/kedro/framework/cli/cli.py index 03c9743500..304fb6b4bc 100644 --- a/kedro/framework/cli/cli.py +++ b/kedro/framework/cli/cli.py @@ -4,7 +4,6 @@ """ import importlib import sys -import webbrowser from collections import defaultdict from pathlib import Path from typing import Sequence @@ -81,19 +80,6 @@ def info(): click.echo("No plugins installed") -@cli.command(short_help="See the kedro API docs and introductory tutorial.") -def docs(): - """Display the online API docs and introductory tutorial in the browser. (DEPRECATED)""" - deprecation_message = ( - "DeprecationWarning: Command 'kedro docs' is deprecated and " - "will not be available from Kedro 0.19.0." 
- ) - click.secho(deprecation_message, fg="red") - index_path = f"https://kedro.readthedocs.io/en/{version}" - click.echo(f"Opening {index_path}") - webbrowser.open(index_path) - - def _init_plugins() -> None: init_hooks = load_entry_points("init") for init_hook in init_hooks: diff --git a/kedro/framework/cli/jupyter.py b/kedro/framework/cli/jupyter.py index e7cfbc166e..d2facef34b 100644 --- a/kedro/framework/cli/jupyter.py +++ b/kedro/framework/cli/jupyter.py @@ -6,20 +6,13 @@ import json import os import shutil -import sys -from collections import Counter -from glob import iglob from pathlib import Path -from typing import Any -from warnings import warn import click -from click import secho from kedro.framework.cli.utils import ( KedroCliError, _check_module_importable, - command_with_verbosity, env_option, forward_command, python_call, @@ -190,119 +183,3 @@ def _create_kernel(kernel_name: str, display_name: str) -> str: f"Cannot setup kedro kernel for Jupyter.\nError: {exc}" ) from exc return kernel_path - - -@command_with_verbosity(jupyter, "convert") -@click.option("--all", "-a", "all_flag", is_flag=True, help=CONVERT_ALL_HELP) -@click.option("-y", "overwrite_flag", is_flag=True, help=OVERWRITE_HELP) -@click.argument( - "filepath", - type=click.Path(exists=True, dir_okay=False, resolve_path=True), - required=False, - nargs=-1, -) -@env_option -@click.pass_obj # this will pass the metadata as first argument -def convert_notebook( - metadata: ProjectMetadata, all_flag, overwrite_flag, filepath, env, **kwargs -): # noqa: unused-argument, too-many-locals - """Convert selected or all notebooks found in a Kedro project - to Kedro code, by exporting code from the appropriately-tagged cells: - Cells tagged as `node` will be copied over to a Python file matching - the name of the notebook, under `//nodes`. - *Note*: Make sure your notebooks have unique names! - FILEPATH: Path(s) to exact notebook file(s) to be converted. Both - relative and absolute paths are accepted. - Should not be provided if --all flag is already present. (DEPRECATED) - """ - - deprecation_message = ( - "DeprecationWarning: Command 'kedro jupyter convert' is deprecated and " - "will not be available from Kedro 0.19.0." - ) - click.secho(deprecation_message, fg="red") - - project_path = metadata.project_path - source_path = metadata.source_dir - package_name = metadata.package_name - - if not filepath and not all_flag: - secho( - "Please specify a notebook filepath " - "or add '--all' to convert all notebooks." - ) - sys.exit(1) - - if all_flag: - # pathlib glob does not ignore hidden directories, - # whereas Python glob does, which is more useful in - # ensuring checkpoints will not be included - pattern = project_path / "**" / "*.ipynb" - notebooks = sorted(Path(p) for p in iglob(str(pattern), recursive=True)) - else: - notebooks = [Path(f) for f in filepath] - - counter = Counter(n.stem for n in notebooks) - non_unique_names = [name for name, counts in counter.items() if counts > 1] - if non_unique_names: - names = ", ".join(non_unique_names) - raise KedroCliError( - f"Found non-unique notebook names! 
Please rename the following: {names}" - ) - - output_dir = source_path / package_name / "nodes" - if not output_dir.is_dir(): - output_dir.mkdir() - (output_dir / "__init__.py").touch() - - for notebook in notebooks: - secho(f"Converting notebook '{notebook}'...") - output_path = output_dir / f"{notebook.stem}.py" - - if output_path.is_file(): - overwrite = overwrite_flag or click.confirm( - f"Output file {output_path} already exists. Overwrite?", default=False - ) - if overwrite: - _export_nodes(notebook, output_path) - else: - _export_nodes(notebook, output_path) - - secho("Done!", color="green") # type: ignore - - -def _export_nodes(filepath: Path, output_path: Path) -> None: - """Copy code from Jupyter cells into nodes in src//nodes/, - under filename with same name as notebook. - - Args: - filepath: Path to Jupyter notebook file - output_path: Path where notebook cells' source code will be exported - Raises: - KedroCliError: When provided a filepath that cannot be read as a - Jupyer notebook and loaded into json format. - """ - try: - content = json.loads(filepath.read_text()) - except json.JSONDecodeError as exc: - raise KedroCliError( - f"Provided filepath is not a Jupyter notebook: {filepath}" - ) from exc - cells = [ - cell - for cell in content["cells"] - if cell["cell_type"] == "code" and "node" in cell["metadata"].get("tags", {}) - ] - - if cells: - output_path.write_text("") - for cell in cells: - _append_source_code(cell, output_path) - else: - warn(f"Skipping notebook '{filepath}' - no nodes to export.") - - -def _append_source_code(cell: dict[str, Any], path: Path) -> None: - source_code = "".join(cell["source"]).strip() + "\n" - with path.open(mode="a") as file_: - file_.write(source_code) diff --git a/kedro/framework/cli/micropkg.py b/kedro/framework/cli/micropkg.py index f42ea0edbf..d063659833 100644 --- a/kedro/framework/cli/micropkg.py +++ b/kedro/framework/cli/micropkg.py @@ -13,7 +13,6 @@ from typing import Any, Iterable, Iterator, List, Tuple, Union import click -from build.util import project_wheel_metadata from packaging.requirements import InvalidRequirement, Requirement from packaging.utils import canonicalize_name from rope.base.project import Project @@ -22,6 +21,7 @@ from rope.refactor.rename import Rename from setuptools.discovery import FlatLayoutPackageFinder +from build.util import project_wheel_metadata from kedro.framework.cli.pipeline import ( _assert_pkg_name_ok, _check_pipeline_name, @@ -962,8 +962,8 @@ def _append_package_reqs( file.write(sep.join(sorted_reqs)) click.secho( - "Use 'kedro build-reqs' to compile and 'pip install -r src/requirements.lock' to install " - "the updated list of requirements." + "Use 'pip-compile src/requirements.txt --output-file src/requirements.lock' to compile " + "and 'pip install -r src/requirements.lock' to install the updated list of requirements." 
) diff --git a/kedro/framework/cli/project.py b/kedro/framework/cli/project.py index f3cf141dfa..e9286e71ad 100644 --- a/kedro/framework/cli/project.py +++ b/kedro/framework/cli/project.py @@ -1,16 +1,12 @@ """A collection of CLI commands for working with Kedro project.""" import os -import shutil -import subprocess import sys -import webbrowser from pathlib import Path import click from kedro.framework.cli.utils import ( - KedroCliError, _check_module_importable, _config_file_callback, _deprecate_options, @@ -19,10 +15,8 @@ _split_load_versions, _split_params, call, - command_with_verbosity, env_option, forward_command, - python_call, split_node_names, split_string, ) @@ -74,63 +68,6 @@ def project_group(): # pragma: no cover pass -@forward_command(project_group, forward_help=True) -@click.pass_obj # this will pass the metadata as first argument -def test(metadata: ProjectMetadata, args, **kwargs): # noqa: ument - """Run the test suite. (DEPRECATED)""" - deprecation_message = ( - "DeprecationWarning: Command 'kedro test' is deprecated and " - "will not be available from Kedro 0.19.0. " - "Use the command 'pytest' instead. " - ) - click.secho(deprecation_message, fg="red") - - try: - _check_module_importable("pytest") - except KedroCliError as exc: - source_path = metadata.source_dir - raise KedroCliError( - NO_DEPENDENCY_MESSAGE.format(module="pytest", src=str(source_path)) - ) from exc - python_call("pytest", args) - - -@command_with_verbosity(project_group) -@click.option("-c", "--check-only", is_flag=True, help=LINT_CHECK_ONLY_HELP) -@click.argument("files", type=click.Path(exists=True), nargs=-1) -@click.pass_obj # this will pass the metadata as first argument -def lint( - metadata: ProjectMetadata, files, check_only, **kwargs -): # noqa: unused-argument - """Run flake8, isort and black. (DEPRECATED)""" - deprecation_message = ( - "DeprecationWarning: Command 'kedro lint' is deprecated and " - "will not be available from Kedro 0.19.0." - ) - click.secho(deprecation_message, fg="red") - - source_path = metadata.source_dir - package_name = metadata.package_name - files = files or (str(source_path / "tests"), str(source_path / package_name)) - - if "PYTHONPATH" not in os.environ: - # isort needs the source path to be in the 'PYTHONPATH' environment - # variable to treat it as a first-party import location - os.environ["PYTHONPATH"] = str(source_path) # pragma: no cover - - for module_name in ("flake8", "isort", "black"): - try: - _check_module_importable(module_name) - except KedroCliError as exc: - raise KedroCliError( - NO_DEPENDENCY_MESSAGE.format(module=module_name, src=str(source_path)) - ) from exc - - python_call("black", ("--check",) + files if check_only else files) - python_call("flake8", files) - python_call("isort", ("--check",) + files if check_only else files) - - @forward_command(project_group, forward_help=True) @env_option @click.pass_obj # this will pass the metadata as first argument @@ -177,145 +114,6 @@ def package(metadata: ProjectMetadata): ) -@project_group.command("build-docs") -@click.option( - "--open", - "-o", - "open_docs", - is_flag=True, - multiple=False, - default=False, - help=OPEN_ARG_HELP, -) -@click.pass_obj # this will pass the metadata as first argument -def build_docs(metadata: ProjectMetadata, open_docs): - """Build the project documentation. (DEPRECATED)""" - deprecation_message = ( - "DeprecationWarning: Command 'kedro build-docs' is deprecated and " - "will not be available from Kedro 0.19.0." 
- ) - click.secho(deprecation_message, fg="red") - - source_path = metadata.source_dir - package_name = metadata.package_name - - python_call("pip", ["install", str(source_path / "[docs]")]) - python_call("pip", ["install", "-r", str(source_path / "requirements.txt")]) - python_call("ipykernel", ["install", "--user", f"--name={package_name}"]) - shutil.rmtree("docs/build", ignore_errors=True) - call( - [ - "sphinx-apidoc", - "--module-first", - "-o", - "docs/source", - str(source_path / package_name), - ] - ) - call(["sphinx-build", "-M", "html", "docs/source", "docs/build", "-a"]) - if open_docs: - docs_page = (Path.cwd() / "docs" / "build" / "html" / "index.html").as_uri() - click.secho(f"Opening {docs_page}") - webbrowser.open(docs_page) - - -@forward_command(project_group, name="build-reqs") -@click.option( - "--input-file", - "input_file", - type=click.Path(exists=True, dir_okay=False, resolve_path=True), - multiple=False, - help=INPUT_FILE_HELP, -) -@click.option( - "--output-file", - "output_file", - multiple=False, - help=OUTPUT_FILE_HELP, -) -@click.pass_obj # this will pass the metadata as first argument -def build_reqs( - metadata: ProjectMetadata, input_file, output_file, args, **kwargs -): # noqa: unused-argument - """Run `pip-compile` on src/requirements.txt or the user defined input file and save - the compiled requirements to src/requirements.lock or the user defined output file. - (DEPRECATED) - """ - deprecation_message = ( - "DeprecationWarning: Command 'kedro build-reqs' is deprecated and " - "will not be available from Kedro 0.19.0." - ) - click.secho(deprecation_message, fg="red") - - source_path = metadata.source_dir - input_file = Path(input_file or source_path / "requirements.txt") - output_file = Path(output_file or source_path / "requirements.lock") - - if input_file.is_file(): - python_call( - "piptools", - [ - "compile", - *args, - str(input_file), - "--output-file", - str(output_file), - ], - ) - - else: - raise FileNotFoundError( - f"File '{input_file}' not found in the project. " - "Please specify another input or create the file and try again." - ) - - click.secho( - f"Requirements built! Please update {input_file.name} " - "if you'd like to make a change in your project's dependencies, " - f"and re-run build-reqs to generate the new {output_file.name}.", - fg="green", - ) - - -@command_with_verbosity(project_group, "activate-nbstripout") -@click.pass_obj # this will pass the metadata as first argument -def activate_nbstripout(metadata: ProjectMetadata, **kwargs): # noqa: unused-argument - """Install the nbstripout git hook to automatically clean notebooks. (DEPRECATED)""" - deprecation_message = ( - "DeprecationWarning: Command 'kedro activate-nbstripout' is deprecated and " - "will not be available from Kedro 0.19.0." - ) - click.secho(deprecation_message, fg="red") - - source_path = metadata.source_dir - click.secho( - ( - "Notebook output cells will be automatically cleared before committing" - " to git." - ), - fg="yellow", - ) - - try: - _check_module_importable("nbstripout") - except KedroCliError as exc: - raise KedroCliError( - NO_DEPENDENCY_MESSAGE.format(module="nbstripout", src=str(source_path)) - ) from exc - - try: - res = subprocess.run( # noqa: subprocess-run-check - ["git", "rev-parse", "--git-dir"], - capture_output=True, - ) - if res.returncode: - raise KedroCliError("Not a git repository. Run 'git init' first.") - except FileNotFoundError as exc: - raise KedroCliError("Git executable not found. 
Install Git first.") from exc - - call(["nbstripout", "--install"]) - - @project_group.command() @click.option( "--from-inputs", diff --git a/kedro/framework/context/context.py b/kedro/framework/context/context.py index 0ba116255a..003ff696ae 100644 --- a/kedro/framework/context/context.py +++ b/kedro/framework/context/context.py @@ -8,6 +8,7 @@ from urllib.parse import urlparse from warnings import warn +from attrs import field, frozen from pluggy import PluginManager from kedro.config import ConfigLoader, MissingConfigException @@ -153,65 +154,42 @@ def _update_nested_dict(old_dict: dict[Any, Any], new_dict: dict[Any, Any]) -> N old_dict[key] = value +def _expand_full_path(project_path: str | Path) -> Path: + return Path(project_path).expanduser().resolve() + + +@frozen class KedroContext: """``KedroContext`` is the base class which holds the configuration and Kedro's main functionality. """ - def __init__( # noqa: too-many-arguments - self, - package_name: str, - project_path: Path | str, - config_loader: ConfigLoader, - hook_manager: PluginManager, - env: str = None, - extra_params: dict[str, Any] = None, - ): - """Create a context object by providing the root of a Kedro project and - the environment configuration subfolders - (see ``kedro.config.ConfigLoader``) - - Raises: - KedroContextError: If there is a mismatch - between Kedro project version and package version. - - Args: - package_name: Package name for the Kedro project the context is - created for. - project_path: Project path to define the context for. - hook_manager: The ``PluginManager`` to activate hooks, supplied by the session. - env: Optional argument for configuration default environment to be used - for running the pipeline. If not specified, it defaults to "local". - extra_params: Optional dictionary containing extra project parameters. - If specified, will update (and therefore take precedence over) - the parameters retrieved from the project configuration. - """ - self._project_path = Path(project_path).expanduser().resolve() - self._package_name = package_name - self._config_loader = config_loader - self._env = env - self._extra_params = deepcopy(extra_params) - self._hook_manager = hook_manager - - @property # type: ignore - def env(self) -> str | None: - """Property for the current Kedro environment. - - Returns: - Name of the current Kedro environment. - - """ - return self._env + _package_name: str + project_path: Path = field(converter=_expand_full_path) + config_loader: ConfigLoader + _hook_manager: PluginManager + env: str | None = None + _extra_params: dict[str, Any] | None = field(default=None, converter=deepcopy) - @property - def project_path(self) -> Path: - """Read-only property containing Kedro's root project directory. + """Create a context object by providing the root of a Kedro project and + the environment configuration subfolders (see ``kedro.config.ConfigLoader``) - Returns: - Project directory. + Raises: + KedroContextError: If there is a mismatch + between Kedro project version and package version. - """ - return self._project_path + Args: + package_name: Package name for the Kedro project the context is + created for. + project_path: Project path to define the context for. + config_loader: Kedro's ``ConfigLoader`` for loading the configuration files. + hook_manager: The ``PluginManager`` to activate hooks, supplied by the session. + env: Optional argument for configuration default environment to be used + for running the pipeline. If not specified, it defaults to "local". 
+ extra_params: Optional dictionary containing extra project parameters. + If specified, will update (and therefore take precedence over) + the parameters retrieved from the project configuration. + """ @property def catalog(self) -> DataCatalog: @@ -241,17 +219,6 @@ def params(self) -> dict[str, Any]: _update_nested_dict(params, self._extra_params or {}) return params - @property - def config_loader(self): - """Read-only property referring to Kedro's ``ConfigLoader`` for this - context. - Returns: - Instance of `ConfigLoader`. - Raises: - KedroContextError: Incorrect ``ConfigLoader`` registered for the project. - """ - return self._config_loader - def _get_catalog( self, save_version: str = None, diff --git a/kedro/framework/project/__init__.py b/kedro/framework/project/__init__.py index f266da430c..ea7369cadf 100644 --- a/kedro/framework/project/__init__.py +++ b/kedro/framework/project/__init__.py @@ -226,6 +226,14 @@ def configure(self, logging_config: dict[str, Any]) -> None: logging.config.dictConfig(logging_config) self.data = logging_config + def set_project_logging(self, package_name: str): + """Add the project level logging to the loggers upon provision of a package name. + Checks if project logger already exists to prevent overwriting, if none exists + it defaults to setting project logs at INFO level.""" + if package_name not in self.data["loggers"]: + self.data["loggers"][package_name] = {"level": "INFO"} + self.configure(self.data) + PACKAGE_NAME = None LOGGING = _ProjectLogging() @@ -252,6 +260,9 @@ def configure_project(package_name: str): global PACKAGE_NAME # noqa: PLW0603 PACKAGE_NAME = package_name + if PACKAGE_NAME: + LOGGING.set_project_logging(PACKAGE_NAME) + def configure_logging(logging_config: dict[str, Any]) -> None: """Configure logging according to ``logging_config`` dictionary.""" diff --git a/kedro/framework/session/session.py b/kedro/framework/session/session.py index 1a997d6d36..dbe2fc2b29 100644 --- a/kedro/framework/session/session.py +++ b/kedro/framework/session/session.py @@ -16,13 +16,11 @@ import click from kedro import __version__ as kedro_version -from kedro.config import ConfigLoader, MissingConfigException, TemplatedConfigLoader +from kedro.config import ConfigLoader, TemplatedConfigLoader from kedro.framework.context import KedroContext -from kedro.framework.context.context import _convert_paths_to_absolute_posix from kedro.framework.hooks import _create_hook_manager from kedro.framework.hooks.manager import _register_hooks, _register_hooks_entry_points from kedro.framework.project import ( - configure_logging, pipelines, settings, validate_settings, @@ -188,36 +186,11 @@ def create( # noqa: too-many-arguments "Unable to get username. 
Full exception: %s", exc ) - session._store.update(session_data) - - # We need ConfigLoader and env to setup logging correctly - session._setup_logging() session_data.update(**_describe_git(session._project_path)) session._store.update(session_data) return session - def _get_logging_config(self) -> dict[str, Any]: - logging_config = self._get_config_loader()["logging"] - # turn relative paths in logging config into absolute path - # before initialising loggers - logging_config = _convert_paths_to_absolute_posix( - project_path=self._project_path, conf_dictionary=logging_config - ) - return logging_config - - def _setup_logging(self) -> None: - """Register logging specified in logging directory.""" - try: - logging_config = self._get_logging_config() - except MissingConfigException: - self._logger.debug( - "No project logging configuration loaded; " - "Kedro's default logging configuration will be used." - ) - else: - configure_logging(logging_config) - def _init_store(self) -> BaseSessionStore: store_class = settings.SESSION_STORE_CLASS classpath = f"{store_class.__module__}.{store_class.__qualname__}" diff --git a/kedro/io/core.py b/kedro/io/core.py index 6a097d7058..e031678b5c 100644 --- a/kedro/io/core.py +++ b/kedro/io/core.py @@ -353,10 +353,7 @@ class Version(namedtuple("Version", ["load", "save"])): "intermediate data sets where possible to avoid this warning." ) -# `kedro_datasets` is probed before `kedro.extras.datasets`, -# hence the DeprecationWarning will not be shown -# if the dataset is available in the former -_DEFAULT_PACKAGES = ["kedro.io.", "kedro_datasets.", "kedro.extras.datasets.", ""] +_DEFAULT_PACKAGES = ["kedro.io.", "kedro_datasets.", ""] def parse_dataset_definition( diff --git a/kedro/io/data_catalog.py b/kedro/io/data_catalog.py index 031abb5b51..a595bfdf87 100644 --- a/kedro/io/data_catalog.py +++ b/kedro/io/data_catalog.py @@ -176,7 +176,7 @@ def __init__( # noqa: too-many-arguments Example: :: - >>> from kedro.extras.datasets.pandas import CSVDataSet + >>> from kedro_datasets.pandas import CSVDataSet >>> >>> cars = CSVDataSet(filepath="cars.csv", >>> load_args=None, @@ -486,7 +486,7 @@ def load(self, name: str, version: str = None) -> Any: :: >>> from kedro.io import DataCatalog - >>> from kedro.extras.datasets.pandas import CSVDataSet + >>> from kedro_datasets.pandas import CSVDataSet >>> >>> cars = CSVDataSet(filepath="cars.csv", >>> load_args=None, @@ -499,7 +499,10 @@ def load(self, name: str, version: str = None) -> Any: dataset = self._get_dataset(name, version=load_version) self._logger.info( - "Loading data from '%s' (%s)...", name, type(dataset).__name__ + "Loading data from [dark_orange]%s[/dark_orange] (%s)...", + name, + type(dataset).__name__, + extra={"markup": True}, ) result = dataset.load() @@ -523,7 +526,7 @@ def save(self, name: str, data: Any) -> None: >>> import pandas as pd >>> - >>> from kedro.extras.datasets.pandas import CSVDataSet + >>> from kedro_datasets.pandas import CSVDataSet >>> >>> cars = CSVDataSet(filepath="cars.csv", >>> load_args=None, @@ -537,7 +540,12 @@ def save(self, name: str, data: Any) -> None: """ dataset = self._get_dataset(name) - self._logger.info("Saving data to '%s' (%s)...", name, type(dataset).__name__) + self._logger.info( + "Saving data to [dark_orange]%s[/dark_orange] (%s)...", + name, + type(dataset).__name__, + extra={"markup": True}, + ) dataset.save(data) @@ -592,7 +600,7 @@ def add( Example: :: - >>> from kedro.extras.datasets.pandas import CSVDataSet + >>> from kedro_datasets.pandas import 
CSVDataSet >>> >>> io = DataCatalog(data_sets={ >>> 'cars': CSVDataSet(filepath="cars.csv") @@ -628,7 +636,7 @@ def add_all( Example: :: - >>> from kedro.extras.datasets.pandas import CSVDataSet, ParquetDataSet + >>> from kedro_datasets.pandas import CSVDataSet, ParquetDataSet >>> >>> io = DataCatalog(data_sets={ >>> "cars": CSVDataSet(filepath="cars.csv") diff --git a/kedro/io/partitioned_dataset.py b/kedro/io/partitioned_dataset.py index bccbd9e628..845c761c67 100644 --- a/kedro/io/partitioned_dataset.py +++ b/kedro/io/partitioned_dataset.py @@ -379,7 +379,7 @@ class IncrementalDataset(PartitionedDataset): >>> data_set.load() """ - DEFAULT_CHECKPOINT_TYPE = "kedro.extras.datasets.text.TextDataSet" + DEFAULT_CHECKPOINT_TYPE = "kedro_datasets.text.TextDataSet" # TODO: PartitionedDataset should move to kedro-datasets DEFAULT_CHECKPOINT_FILENAME = "CHECKPOINT" def __init__( # noqa: too-many-arguments diff --git a/kedro/pipeline/node.py b/kedro/pipeline/node.py index d9435308c1..2d3f2b5215 100644 --- a/kedro/pipeline/node.py +++ b/kedro/pipeline/node.py @@ -163,7 +163,7 @@ def __hash__(self): def __str__(self): def _set_to_str(xset): - return f"[{','.join(xset)}]" + return f"[{';'.join(xset)}]" out_str = _set_to_str(self.outputs) if self._outputs else "None" in_str = _set_to_str(self.inputs) if self._inputs else "None" @@ -353,7 +353,12 @@ def run(self, inputs: dict[str, Any] = None) -> dict[str, Any]: # purposely catch all exceptions except Exception as exc: - self._logger.error("Node '%s' failed with error: \n%s", str(self), str(exc)) + self._logger.error( + "Node %s failed with error: \n%s", + str(self), + str(exc), + extra={"markup": True}, + ) raise exc def _run_with_no_inputs(self, inputs: dict[str, Any]): diff --git a/kedro/templates/project/{{ cookiecutter.repo_name }}/.flake8 b/kedro/templates/project/{{ cookiecutter.repo_name }}/.flake8 deleted file mode 100644 index 63ea673001..0000000000 --- a/kedro/templates/project/{{ cookiecutter.repo_name }}/.flake8 +++ /dev/null @@ -1,3 +0,0 @@ -[flake8] -max-line-length=88 -extend-ignore=E203 diff --git a/kedro/templates/project/{{ cookiecutter.repo_name }}/README.md b/kedro/templates/project/{{ cookiecutter.repo_name }}/README.md index 07ae44d46c..19d9afc130 100644 --- a/kedro/templates/project/{{ cookiecutter.repo_name }}/README.md +++ b/kedro/templates/project/{{ cookiecutter.repo_name }}/README.md @@ -38,22 +38,15 @@ kedro run Have a look at the file `src/tests/test_run.py` for instructions on how to write your tests. You can run your tests as follows: ``` -kedro test +pytest ``` -To configure the coverage threshold, go to the `.coveragerc` file. +To configure the coverage threshold, look at the `.coveragerc` file. -## Project dependencies - -To generate or update the dependency requirements for your project: - -``` -kedro build-reqs -``` -This will `pip-compile` the contents of `src/requirements.txt` into a new file `src/requirements.lock`. You can see the output of the resolution by opening `src/requirements.lock`. +## Project dependencies -After this, if you'd like to update your project requirements, please update `src/requirements.txt` and re-run `kedro build-reqs`. +To see and update the dependency requirements for your project use `src/requirements.txt`. You can install the project requirements with `pip install -r src/requirements.txt`. 
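Datasets now come from the separate `kedro-datasets` package rather than `kedro.extras.datasets` (see the `DataCatalog` docstring updates earlier in this diff, and `kedro-datasets` being added to the core dependencies in `pyproject.toml`). As a minimal, hypothetical sketch of the new import path — the `cars.csv` filepath and the toy DataFrame below are illustrative only, and it assumes `kedro-datasets` and pandas are installed:

```
# Sketch only: mirrors the updated DataCatalog docstring examples in this diff.
import pandas as pd

from kedro.io import DataCatalog
from kedro_datasets.pandas import CSVDataSet  # was kedro.extras.datasets.pandas

# Hypothetical local file and toy data, purely for illustration.
cars = CSVDataSet(filepath="cars.csv", save_args={"index": False})
catalog = DataCatalog(data_sets={"cars": cars})

catalog.save("cars", pd.DataFrame({"model": ["a", "b"], "mpg": [30, 25]}))
reloaded = catalog.load("cars")  # the catalog logger reports the load, e.g. "Loading data from cars (CSVDataSet)..."
print(reloaded.head())
```

Because `kedro-datasets` is declared as a dependency of Kedro itself in this change, the import above should resolve without any extra installs.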
[Further information about project dependencies](https://docs.kedro.org/en/stable/kedro_project_setup/dependencies.html#project-specific-dependencies) @@ -96,24 +89,8 @@ And if you want to run an IPython session: kedro ipython ``` -### How to convert notebook cells to nodes in a Kedro project -You can move notebook code over into a Kedro project structure using a mixture of [cell tagging](https://jupyter-notebook.readthedocs.io/en/stable/changelog.html#release-5-0-0) and Kedro CLI commands. - -By adding the `node` tag to a cell and running the command below, the cell's source code will be copied over to a Python file within `src//nodes/`: - -``` -kedro jupyter convert -``` -> *Note:* The name of the Python file matches the name of the original notebook. - -Alternatively, you may want to transform all your notebooks in one go. Run the following command to convert all notebook files found in the project root directory and under any of its sub-folders: - -``` -kedro jupyter convert --all -``` - ### How to ignore notebook output cells in `git` -To automatically strip out all output cell contents before committing to `git`, you can run `kedro activate-nbstripout`. This will add a hook in `.git/config` which will run `nbstripout` before anything is committed to `git`. +To automatically strip out all output cell contents before committing to `git`, you can use tools like [`nbstripout`](https://github.com/kynan/nbstripout). For example, you can add a hook in `.git/config` with `nbstripout --install`. This will run `nbstripout` before anything is committed to `git`. > *Note:* Your output cells will be retained locally. diff --git a/kedro/templates/project/{{ cookiecutter.repo_name }}/conf/base/logging.yml b/kedro/templates/project/{{ cookiecutter.repo_name }}/conf/base/logging.yml deleted file mode 100644 index c6a6fc7057..0000000000 --- a/kedro/templates/project/{{ cookiecutter.repo_name }}/conf/base/logging.yml +++ /dev/null @@ -1,41 +0,0 @@ -version: 1 - -disable_existing_loggers: False - -formatters: - simple: - format: "%(asctime)s - %(name)s - %(levelname)s - %(message)s" - -handlers: - console: - class: logging.StreamHandler - level: INFO - formatter: simple - stream: ext://sys.stdout - - info_file_handler: - class: logging.handlers.RotatingFileHandler - level: INFO - formatter: simple - filename: info.log - maxBytes: 10485760 # 10MB - backupCount: 20 - encoding: utf8 - delay: True - - rich: - class: kedro.logging.RichHandler - rich_tracebacks: True - # Advance options for customisation. 
- # See https://docs.kedro.org/en/stable/logging/logging.html#project-side-logging-configuration - # tracebacks_show_locals: False - -loggers: - kedro: - level: INFO - - {{ cookiecutter.python_package }}: - level: INFO - -root: - handlers: [rich, info_file_handler] diff --git a/kedro/templates/project/{{ cookiecutter.repo_name }}/pyproject.toml b/kedro/templates/project/{{ cookiecutter.repo_name }}/pyproject.toml index 7ae06368bd..48be962dc7 100644 --- a/kedro/templates/project/{{ cookiecutter.repo_name }}/pyproject.toml +++ b/kedro/templates/project/{{ cookiecutter.repo_name }}/pyproject.toml @@ -3,9 +3,6 @@ package_name = "{{ cookiecutter.python_package }}" project_name = "{{ cookiecutter.project_name }}" kedro_init_version = "{{ cookiecutter.kedro_version }}" -[tool.isort] -profile = "black" - [tool.pytest.ini_options] addopts = """ --cov-report term-missing \ diff --git a/kedro/templates/project/{{ cookiecutter.repo_name }}/src/pyproject.toml b/kedro/templates/project/{{ cookiecutter.repo_name }}/src/pyproject.toml index de4410c1af..5157ec7b3b 100644 --- a/kedro/templates/project/{{ cookiecutter.repo_name }}/src/pyproject.toml +++ b/kedro/templates/project/{{ cookiecutter.repo_name }}/src/pyproject.toml @@ -19,7 +19,6 @@ docs = [ "sphinx~=3.4.3", "sphinx_rtd_theme==0.5.1", "nbsphinx==0.8.1", - "nbstripout~=0.4", "sphinx-autodoc-typehints==1.11.1", "sphinx_copybutton==0.3.1", "ipykernel>=5.3, <7.0", diff --git a/kedro/templates/project/{{ cookiecutter.repo_name }}/src/requirements.txt b/kedro/templates/project/{{ cookiecutter.repo_name }}/src/requirements.txt index f7b52ec2c9..02ae471d84 100644 --- a/kedro/templates/project/{{ cookiecutter.repo_name }}/src/requirements.txt +++ b/kedro/templates/project/{{ cookiecutter.repo_name }}/src/requirements.txt @@ -1,14 +1,11 @@ black~=22.0 -flake8>=3.7.9, <5.0 ipython>=7.31.1, <8.0; python_version < '3.8' ipython~=8.10; python_version >= '3.8' -isort~=5.0 jupyter~=1.0 jupyterlab_server>=2.11.1, <2.16.0 jupyterlab~=3.0, <3.6.0 kedro~={{ cookiecutter.kedro_version }} kedro-telemetry~=0.2.0 -nbstripout~=0.4 pytest-cov~=3.0 pytest-mock>=1.7.1, <2.0 pytest~=7.2 diff --git a/kedro/templates/project/{{ cookiecutter.repo_name }}/src/tests/test_run.py b/kedro/templates/project/{{ cookiecutter.repo_name }}/src/tests/test_run.py index 785c5a40b9..eb57d1908e 100644 --- a/kedro/templates/project/{{ cookiecutter.repo_name }}/src/tests/test_run.py +++ b/kedro/templates/project/{{ cookiecutter.repo_name }}/src/tests/test_run.py @@ -5,7 +5,7 @@ project's structure, and in files named test_*.py. They are simply functions named ``test_*`` which test a unit of logic. -To run the tests, run ``kedro test`` from the project root directory. +To run the tests, run ``pytest`` from the project root directory. """ from pathlib import Path diff --git a/pyproject.toml b/pyproject.toml index 5ed2bfa43d..11a5679e3c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -10,7 +10,7 @@ authors = [ {name = "Kedro"} ] description = "Kedro helps you build production-ready data and analytics pipelines" -requires-python = ">=3.7" +requires-python = ">=3.8" dependencies = [ "anyconfig>=0.10.0", "attrs>=21.3", @@ -21,15 +21,15 @@ dependencies = [ "dynaconf>=3.1.2,<4.0", "fsspec>=2021.4", "gitpython>=3.0", - "importlib_metadata>=3.6,<5.0; python_version < '3.8'", # The "selectable" entry points were introduced in `importlib_metadata` 3.6 and Python 3.10. 
Bandit on Python 3.7 relies on a library with `importlib_metadata` < 5.0 "importlib-metadata>=3.6,<7.0; python_version >= '3.8'", "importlib_resources>=1.3,<7.0", # The `files()` API was introduced in `importlib_resources` 1.3 and Python 3.9. "jmespath>=0.9.5", + "kedro-datasets", "more_itertools>=8.14.0", "omegaconf>=2.1.1", "parse>=1.19.0", "pip-tools>=6.5", - "pluggy>=1.0,<1.3", # TODO: Uncap when dropping Python 3.7 support, see https://github.com/kedro-org/kedro/issues/2979 + "pluggy>=1.0", "PyYAML>=4.2,<7.0", "rich>=12.0,<14.0", "rope>=0.21,<2.0", # subject to LGPLv3 license @@ -47,13 +47,62 @@ keywords = [ license = {text = "Apache Software License (Apache 2.0)"} classifiers = [ "Development Status :: 4 - Beta", - "Programming Language :: Python :: 3.7", "Programming Language :: Python :: 3.8", "Programming Language :: Python :: 3.9", "Programming Language :: Python :: 3.10", "Programming Language :: Python :: 3.11", ] -dynamic = ["readme", "version", "optional-dependencies"] +dynamic = ["readme", "version"] + +[project.optional-dependencies] +test = [ + "bandit>=1.6.2, <2.0", + "behave==1.2.6", + "blacken-docs==1.9.2", + "black~=22.0", + "coverage[toml]", + "fsspec<2023.9", # Temporary, newer version causing "test_no_versions_with_cloud_protocol" to fail + "import-linter[toml]==1.8.0", + "ipython>=7.31.1, <8.0; python_version < '3.8'", + "ipython~=8.10; python_version >= '3.8'", + "Jinja2<3.1.0", + "jupyterlab_server>=2.11.1", + "jupyterlab~=3.0", + "jupyter~=1.0", + "memory_profiler>=0.50.0, <1.0", + "moto==1.3.7; python_version < '3.10'", + "moto==4.1.12; python_version >= '3.10'", + "pandas~=2.0", + "pre-commit>=2.9.2, <3.0", # The hook `mypy` requires pre-commit version 2.9.2. + "pyarrow>=1.0; python_version < '3.11'", + "pyarrow>=7.0; python_version >= '3.11'", # Adding to avoid numpy build errors + "pyproj~=3.0", + "pytest-cov~=3.0", + "pytest-mock>=1.7.1, <2.0", + "pytest-xdist[psutil]~=2.2.1", + "pytest~=7.2", + "s3fs>=0.3.0, <0.5", # Needs to be at least 0.3.0 to make use of `cachable` attribute on S3FileSystem. 
+ "semver", + "trufflehog~=2.1", +] +docs = [ + # docutils>=0.17 changed the HTML + # see https://github.com/readthedocs/sphinx_rtd_theme/issues/1115 + "docutils==0.16", + "sphinx~=5.3.0", + "sphinx_rtd_theme==1.2.0", + # Regression on sphinx-autodoc-typehints 1.21 + # that creates some problematic docstrings + "sphinx-autodoc-typehints==1.20.2", + "sphinx_copybutton==0.3.1", + "sphinx-notfound-page", + "ipykernel>=5.3, <7.0", + "sphinxcontrib-mermaid~=0.7.1", + "myst-parser~=1.0.0", + "Jinja2<3.1.0", + "kedro-datasets[all]~=1.7.0", +] +all = [ "kedro[test,docs]" ] [project.urls] Homepage = "https://kedro.org" @@ -77,40 +126,6 @@ version = {attr = "kedro.__version__"} [tool.black] exclude = "/templates/|^features/steps/test_starter" -[tool.isort] -profile = "black" - - -[tool.pylint] -[tool.pylint.master] -ignore = "CVS" -ignore-patterns = "kedro/templates/*" -load-plugins = [ - "pylint.extensions.docparams", - "pylint.extensions.no_self_use" -] -extension-pkg-whitelist = "cv2" -unsafe-load-any-extension = false -[tool.pylint.messages_control] -disable = [ - "ungrouped-imports", - "duplicate-code", - "wrong-import-order", # taken care of by isort -] -enable = ["useless-suppression"] -[tool.pylint.refactoring] -max-nested-blocks = 5 -[tool.pylint.format] -indent-after-paren=4 -indent-string=" " -[tool.pylint.miscellaneous] -notes = [ - "FIXME", - "XXX" -] -[tool.pylint.design] -min-public-methods = 1 - [tool.coverage.report] fail_under = 100 show_missing = true @@ -120,9 +135,9 @@ omit = [ "kedro/extras/extensions/ipython.py", "kedro/framework/cli/hooks/specs.py", "kedro/framework/hooks/specs.py", - "kedro/extras/datasets/tensorflow/*", - "kedro/extras/datasets/holoviews/*", - "tests/*" + "tests/*", + "kedro/io/core.py", # TODO: temp fix for removing datasets, need to resolve this before 0.19 + "kedro/runner/parallel_runner.py" ] exclude_lines = ["pragma: no cover", "raise NotImplementedError"] @@ -156,7 +171,6 @@ layers = [ "framework.context", "framework.project", "runner", - "extras.datasets", "io", "pipeline", "config" @@ -184,7 +198,6 @@ forbidden_modules = [ "kedro.runner", "kedro.io", "kedro.pipeline", - "kedro.extras.datasets" ] [[tool.importlinter.contracts]] @@ -194,7 +207,6 @@ source_modules = [ "kedro.runner", "kedro.io", "kedro.pipeline", - "kedro.extras.datasets" ] forbidden_modules = [ "kedro.config" diff --git a/setup.py b/setup.py index afea8c9587..173e0e8ce1 100644 --- a/setup.py +++ b/setup.py @@ -1,14 +1,7 @@ from glob import glob -from itertools import chain from setuptools import setup -# at least 1.3 to be able to use XMLDataSet and pandas integration with fsspec -PANDAS = "pandas~=1.3" -SPARK = "pyspark>=2.2, <3.4" -HDFS = "hdfs>=2.5.8, <3.0" -S3FS = "s3fs>=0.3.0, <0.5" - template_files = [] for pattern in ["**/*", "**/.*", "**/.*/**", "**/.*/.**"]: template_files.extend( @@ -18,196 +11,8 @@ ] ) - -def _collect_requirements(requires): - return sorted(set(chain.from_iterable(requires.values()))) - - -api_require = {"api.APIDataSet": ["requests~=2.20"]} -biosequence_require = {"biosequence.BioSequenceDataSet": ["biopython~=1.73"]} -dask_require = {"dask.ParquetDataSet": ["dask[complete]~=2021.10", "triad>=0.6.7, <1.0"]} -geopandas_require = { - "geopandas.GeoJSONDataSet": ["geopandas>=0.6.0, <1.0", "pyproj~=3.0"] -} -matplotlib_require = {"matplotlib.MatplotlibWriter": ["matplotlib>=3.0.3, <4.0"]} -holoviews_require = {"holoviews.HoloviewsWriter": ["holoviews>=1.13.0"]} -networkx_require = {"networkx.NetworkXDataSet": ["networkx~=2.4"]} -pandas_require = { - 
"pandas.CSVDataSet": [PANDAS], - "pandas.ExcelDataSet": [PANDAS, "openpyxl>=3.0.6, <4.0"], - "pandas.FeatherDataSet": [PANDAS], - "pandas.GBQTableDataSet": [PANDAS, "pandas-gbq>=0.12.0, <0.18.0"], - "pandas.GBQQueryDataSet": [PANDAS, "pandas-gbq>=0.12.0, <0.18.0"], - "pandas.HDFDataSet": [ - PANDAS, - "tables~=3.6.0; platform_system == 'Windows'", - "tables~=3.6; platform_system != 'Windows'", - ], - "pandas.JSONDataSet": [PANDAS], - "pandas.ParquetDataSet": [PANDAS, "pyarrow>=1.0, <7.0"], - "pandas.SQLTableDataSet": [PANDAS, "SQLAlchemy~=1.2"], - "pandas.SQLQueryDataSet": [PANDAS, "SQLAlchemy~=1.2"], - "pandas.XMLDataSet": [PANDAS, "lxml~=4.6"], - "pandas.GenericDataSet": [PANDAS], -} -pickle_require = {"pickle.PickleDataSet": ["compress-pickle[lz4]~=2.1.0"]} -pillow_require = {"pillow.ImageDataSet": ["Pillow~=9.0"]} -video_require = { - "video.VideoDataSet": ["opencv-python~=4.5.5.64"] -} -plotly_require = { - "plotly.PlotlyDataSet": [PANDAS, "plotly>=4.8.0, <6.0"], - "plotly.JSONDataSet": ["plotly>=4.8.0, <6.0"], -} -redis_require = {"redis.PickleDataSet": ["redis~=4.1"]} -spark_require = { - "spark.SparkDataSet": [SPARK, HDFS, S3FS], - "spark.SparkHiveDataSet": [SPARK, HDFS, S3FS], - "spark.SparkJDBCDataSet": [SPARK, HDFS, S3FS], - "spark.DeltaTableDataSet": [SPARK, HDFS, S3FS, "delta-spark>=1.0, <3.0"], -} -svmlight_require = {"svmlight.SVMLightDataSet": ["scikit-learn~=1.0.2", "scipy~=1.7.3"]} -tensorflow_required = { - "tensorflow.TensorflowModelDataset": [ - # currently only TensorFlow V2 supported for saving and loading. - # V1 requires HDF5 and serialises differently - "tensorflow~=2.0; platform_system != 'Darwin' or platform_machine != 'arm64'", - # https://developer.apple.com/metal/tensorflow-plugin/ - "tensorflow-macos~=2.0; platform_system == 'Darwin' and platform_machine == 'arm64'", - ] -} -yaml_require = {"yaml.YAMLDataSet": [PANDAS, "PyYAML>=4.2, <7.0"]} - -extras_require = { - "api": _collect_requirements(api_require), - "biosequence": _collect_requirements(biosequence_require), - "dask": _collect_requirements(dask_require), - "docs": [ - # docutils>=0.17 changed the HTML - # see https://github.com/readthedocs/sphinx_rtd_theme/issues/1115 - "docutils==0.16", - "sphinx~=5.3.0", - "sphinx_rtd_theme==1.2.0", - # Regression on sphinx-autodoc-typehints 1.21 - # that creates some problematic docstrings - "sphinx-autodoc-typehints==1.20.2", - "sphinx_copybutton==0.3.1", - "sphinx-notfound-page", - "ipykernel>=5.3, <7.0", - "sphinxcontrib-mermaid~=0.7.1", - "myst-parser~=1.0.0", - "Jinja2<3.1.0", - "kedro-datasets[all]~=1.7.0", - ], - "geopandas": _collect_requirements(geopandas_require), - "matplotlib": _collect_requirements(matplotlib_require), - "holoviews": _collect_requirements(holoviews_require), - "networkx": _collect_requirements(networkx_require), - "pandas": _collect_requirements(pandas_require), - "pickle": _collect_requirements(pickle_require), - "pillow": _collect_requirements(pillow_require), - "video": _collect_requirements(video_require), - "plotly": _collect_requirements(plotly_require), - "redis": _collect_requirements(redis_require), - "spark": _collect_requirements(spark_require), - "svmlight": _collect_requirements(svmlight_require), - "tensorflow": _collect_requirements(tensorflow_required), - "yaml": _collect_requirements(yaml_require), - **api_require, - **biosequence_require, - **dask_require, - **geopandas_require, - **matplotlib_require, - **holoviews_require, - **networkx_require, - **pandas_require, - **pickle_require, - **pillow_require, - 
**video_require, - **plotly_require, - **spark_require, - **svmlight_require, - **tensorflow_required, - **yaml_require, -} - -extras_require["all"] = _collect_requirements(extras_require) -extras_require["test"] = [ - "adlfs>=2021.7.1, <=2022.2; python_version == '3.7'", - "adlfs~=2023.1; python_version >= '3.8'", - "bandit>=1.6.2, <2.0", - "behave==1.2.6", - "biopython~=1.73", - "blacken-docs==1.9.2", - "black~=22.0", - "compress-pickle[lz4]~=2.1.0", - "coverage[toml]", - "dask[complete]~=2021.10", # pinned by Snyk to avoid a vulnerability - "delta-spark>=1.2.1; python_version >= '3.11'", # 1.2.0 has a bug that breaks some of our tests: https://github.com/delta-io/delta/issues/1070 - "delta-spark~=1.2.1; python_version < '3.11'", - "dill~=0.3.1", - "filelock>=3.4.0, <4.0", - "gcsfs>=2021.4, <=2023.1; python_version == '3.7'", - "gcsfs>=2023.1, <2023.3; python_version >= '3.8'", - "geopandas>=0.6.0, <1.0", - "hdfs>=2.5.8, <3.0", - "holoviews>=1.13.0", - "import-linter[toml]==1.8.0", - "ipython>=7.31.1, <8.0; python_version < '3.8'", - "ipython~=8.10; python_version >= '3.8'", - "isort~=5.0", - "Jinja2<3.1.0", - "joblib>=0.14", - "jupyterlab_server>=2.11.1, <2.16.0", # 2.16.0 requires importlib_metedata >= 4.8.3 which conflicts with flake8 requirement - "jupyterlab~=3.0, <3.6.0", # 3.6.0 requires jupyterlab_server~=2.19 - "jupyter~=1.0", - "lxml~=4.6", - "matplotlib>=3.0.3, <3.4; python_version < '3.10'", # 3.4.0 breaks holoviews - "matplotlib>=3.5, <3.6; python_version >= '3.10'", - "memory_profiler>=0.50.0, <1.0", - "moto==1.3.7; python_version < '3.10'", - "moto==4.1.12; python_version >= '3.10'", - "networkx~=2.4", - "opencv-python~=4.5.5.64", - "openpyxl>=3.0.3, <4.0", - "pandas-gbq>=0.12.0, <0.18.0; python_version < '3.11'", - "pandas-gbq>=0.18.0; python_version >= '3.11'", - "pandas~=1.3 # 1.3 for read_xml/to_xml", - "Pillow~=9.0", - "plotly>=4.8.0, <6.0", - "pre-commit>=2.9.2, <3.0", # The hook `mypy` requires pre-commit version 2.9.2. - "pyarrow>=1.0; python_version < '3.11'", - "pyarrow>=7.0; python_version >= '3.11'", # Adding to avoid numpy build errors - "pylint>=2.17.0, <3.0", - "pyproj~=3.0", - "pyspark>=2.2, <3.4; python_version < '3.11'", - "pyspark>=3.4; python_version >= '3.11'", - "pytest-cov~=3.0", - "pytest-mock>=1.7.1, <2.0", - "pytest-xdist[psutil]~=2.2.1", - "pytest~=7.2", - "redis~=4.1", - "requests-mock~=1.6", - "requests~=2.20", - "s3fs>=0.3.0, <0.5", # Needs to be at least 0.3.0 to make use of `cachable` attribute on S3FileSystem. 
- "scikit-learn>=1.0.2,<2", - "scipy>=1.7.3", - "semver", - "SQLAlchemy~=1.2", - "tables~=3.6.0; platform_system == 'Windows' and python_version<'3.8'", - "tables~=3.8.0; platform_system == 'Windows' and python_version>='3.8'", # Import issues with python 3.8 with pytables pinning to 3.8.0 fixes this https://github.com/PyTables/PyTables/issues/933#issuecomment-1555917593 - "tables~=3.6; platform_system != 'Windows'", - "tensorflow~=2.0; platform_system != 'Darwin' or platform_machine != 'arm64'", - # https://developer.apple.com/metal/tensorflow-plugin/ - "tensorflow-macos~=2.0; platform_system == 'Darwin' and platform_machine == 'arm64'", - "triad>=0.6.7, <1.0", - "trufflehog~=2.1", - "xlsxwriter~=1.0", -] - setup( package_data={ "kedro": ["py.typed"] + template_files }, - extras_require=extras_require, ) diff --git a/tests/config/test_config.py b/tests/config/test_config.py index 110d3de692..fd34f8edf8 100644 --- a/tests/config/test_config.py +++ b/tests/config/test_config.py @@ -257,7 +257,6 @@ def test_no_files_found(self, tmp_path): def test_key_not_found_dict_get(self, tmp_path): """Check the error if no config files satisfy a given pattern""" with pytest.raises(KeyError): - # pylint: disable=expression-not-assigned ConfigLoader(str(tmp_path), _DEFAULT_RUN_ENV)["non-existent-pattern"] @use_config_dir @@ -271,7 +270,6 @@ def test_no_files_found_dict_get(self, tmp_path): r"\[\'credentials\*\', \'credentials\*/\**\', \'\**/credentials\*\'\]" ) with pytest.raises(MissingConfigException, match=pattern): - # pylint: disable=expression-not-assigned ConfigLoader(str(tmp_path), _DEFAULT_RUN_ENV)["credentials"] def test_duplicate_paths(self, tmp_path, caplog): diff --git a/tests/config/test_omegaconf_config.py b/tests/config/test_omegaconf_config.py index 5f2fa2e67f..4a99458f19 100644 --- a/tests/config/test_omegaconf_config.py +++ b/tests/config/test_omegaconf_config.py @@ -1,4 +1,3 @@ -# pylint: disable=expression-not-assigned, pointless-statement from __future__ import annotations import configparser @@ -518,7 +517,7 @@ def test_env_resolver_is_registered_after_loading(self, tmp_path): @use_config_dir def test_load_config_from_tar_file(self, tmp_path): - subprocess.run( # pylint: disable=subprocess-run-check + subprocess.run( [ "tar", "--exclude=local/*.yml", @@ -588,13 +587,11 @@ def test_runtime_params_not_propogate_non_parameters_config(self, tmp_path): parameters = conf["parameters"] catalog = conf["catalog"] credentials = conf["credentials"] - logging = conf["logging"] spark = conf["spark"] assert key in parameters assert key not in catalog assert key not in credentials - assert key not in logging assert key not in spark def test_ignore_hidden_keys(self, tmp_path): diff --git a/tests/extras/__init__.py b/tests/extras/__init__.py deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/tests/extras/datasets/__init__.py b/tests/extras/datasets/__init__.py deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/tests/extras/datasets/api/__init__.py b/tests/extras/datasets/api/__init__.py deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/tests/extras/datasets/api/test_api_dataset.py b/tests/extras/datasets/api/test_api_dataset.py deleted file mode 100644 index f08bd41b92..0000000000 --- a/tests/extras/datasets/api/test_api_dataset.py +++ /dev/null @@ -1,170 +0,0 @@ -# pylint: disable=no-member -import json -import socket - -import pytest -import requests -import requests_mock - -from kedro.extras.datasets.api import APIDataSet -from kedro.io.core import 
DatasetError - -POSSIBLE_METHODS = ["GET", "OPTIONS", "HEAD", "POST", "PUT", "PATCH", "DELETE"] - -TEST_URL = "http://example.com/api/test" -TEST_TEXT_RESPONSE_DATA = "This is a response." -TEST_JSON_RESPONSE_DATA = [{"key": "value"}] - -TEST_PARAMS = {"param": "value"} -TEST_URL_WITH_PARAMS = TEST_URL + "?param=value" - -TEST_HEADERS = {"key": "value"} - - -@pytest.mark.parametrize("method", POSSIBLE_METHODS) -class TestAPIDataSet: - @pytest.fixture - def requests_mocker(self): - with requests_mock.Mocker() as mock: - yield mock - - def test_successfully_load_with_response(self, requests_mocker, method): - api_data_set = APIDataSet( - url=TEST_URL, method=method, params=TEST_PARAMS, headers=TEST_HEADERS - ) - requests_mocker.register_uri( - method, - TEST_URL_WITH_PARAMS, - headers=TEST_HEADERS, - text=TEST_TEXT_RESPONSE_DATA, - ) - - response = api_data_set.load() - assert isinstance(response, requests.Response) - assert response.text == TEST_TEXT_RESPONSE_DATA - - def test_successful_json_load_with_response(self, requests_mocker, method): - api_data_set = APIDataSet( - url=TEST_URL, - method=method, - json=TEST_JSON_RESPONSE_DATA, - headers=TEST_HEADERS, - ) - requests_mocker.register_uri( - method, - TEST_URL, - headers=TEST_HEADERS, - text=json.dumps(TEST_JSON_RESPONSE_DATA), - ) - - response = api_data_set.load() - assert isinstance(response, requests.Response) - assert response.json() == TEST_JSON_RESPONSE_DATA - - def test_http_error(self, requests_mocker, method): - api_data_set = APIDataSet( - url=TEST_URL, method=method, params=TEST_PARAMS, headers=TEST_HEADERS - ) - requests_mocker.register_uri( - method, - TEST_URL_WITH_PARAMS, - headers=TEST_HEADERS, - text="Nope, not found", - status_code=requests.codes.FORBIDDEN, - ) - - with pytest.raises(DatasetError, match="Failed to fetch data"): - api_data_set.load() - - def test_socket_error(self, requests_mocker, method): - api_data_set = APIDataSet( - url=TEST_URL, method=method, params=TEST_PARAMS, headers=TEST_HEADERS - ) - requests_mocker.register_uri(method, TEST_URL_WITH_PARAMS, exc=socket.error) - - with pytest.raises(DatasetError, match="Failed to connect"): - api_data_set.load() - - def test_read_only_mode(self, method): - """ - Saving is disabled on the data set. - """ - api_data_set = APIDataSet(url=TEST_URL, method=method) - with pytest.raises(DatasetError, match="is a read only data set type"): - api_data_set.save({}) - - def test_exists_http_error(self, requests_mocker, method): - """ - In case of an unexpected HTTP error, - ``exists()`` should not silently catch it. - """ - api_data_set = APIDataSet( - url=TEST_URL, method=method, params=TEST_PARAMS, headers=TEST_HEADERS - ) - requests_mocker.register_uri( - method, - TEST_URL_WITH_PARAMS, - headers=TEST_HEADERS, - text="Nope, not found", - status_code=requests.codes.FORBIDDEN, - ) - with pytest.raises(DatasetError, match="Failed to fetch data"): - api_data_set.exists() - - def test_exists_ok(self, requests_mocker, method): - """ - If the file actually exists and server responds 200, - ``exists()`` should return True - """ - api_data_set = APIDataSet( - url=TEST_URL, method=method, params=TEST_PARAMS, headers=TEST_HEADERS - ) - requests_mocker.register_uri( - method, - TEST_URL_WITH_PARAMS, - headers=TEST_HEADERS, - text=TEST_TEXT_RESPONSE_DATA, - ) - - assert api_data_set.exists() - - def test_credentials_auth_error(self, method): - """ - If ``auth`` and ``credentials`` are both provided, - the constructor should raise a ValueError. 
- """ - with pytest.raises(ValueError, match="both auth and credentials"): - APIDataSet(url=TEST_URL, method=method, auth=[], credentials=[]) - - @pytest.mark.parametrize("auth_kwarg", ["auth", "credentials"]) - @pytest.mark.parametrize( - "auth_seq", - [ - ("username", "password"), - ["username", "password"], - (e for e in ["username", "password"]), # Generator. - ], - ) - def test_auth_sequence(self, requests_mocker, method, auth_seq, auth_kwarg): - """ - ``auth`` and ``credentials`` should be able to be any Iterable. - """ - kwargs = { - "url": TEST_URL, - "method": method, - "params": TEST_PARAMS, - "headers": TEST_HEADERS, - auth_kwarg: auth_seq, - } - - api_data_set = APIDataSet(**kwargs) - requests_mocker.register_uri( - method, - TEST_URL_WITH_PARAMS, - headers=TEST_HEADERS, - text=TEST_TEXT_RESPONSE_DATA, - ) - - response = api_data_set.load() - assert isinstance(response, requests.Response) - assert response.text == TEST_TEXT_RESPONSE_DATA diff --git a/tests/extras/datasets/bioinformatics/__init__.py b/tests/extras/datasets/bioinformatics/__init__.py deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/tests/extras/datasets/bioinformatics/test_biosequence_dataset.py b/tests/extras/datasets/bioinformatics/test_biosequence_dataset.py deleted file mode 100644 index b26271cb36..0000000000 --- a/tests/extras/datasets/bioinformatics/test_biosequence_dataset.py +++ /dev/null @@ -1,107 +0,0 @@ -from io import StringIO -from pathlib import PurePosixPath - -import pytest -from Bio import SeqIO -from fsspec.implementations.http import HTTPFileSystem -from fsspec.implementations.local import LocalFileSystem -from gcsfs import GCSFileSystem -from s3fs.core import S3FileSystem - -from kedro.extras.datasets.biosequence import BioSequenceDataSet -from kedro.io import DatasetError -from kedro.io.core import PROTOCOL_DELIMITER - -LOAD_ARGS = {"format": "fasta"} -SAVE_ARGS = {"format": "fasta"} - - -@pytest.fixture -def filepath_biosequence(tmp_path): - return str(tmp_path / "test.fasta") - - -@pytest.fixture -def biosequence_data_set(filepath_biosequence, fs_args): - return BioSequenceDataSet( - filepath=filepath_biosequence, - load_args=LOAD_ARGS, - save_args=SAVE_ARGS, - fs_args=fs_args, - ) - - -@pytest.fixture(scope="module") -def dummy_data(): - data = ">Alpha\nACCGGATGTA\n>Beta\nAGGCTCGGTTA\n" - return list(SeqIO.parse(StringIO(data), "fasta")) - - -class TestBioSequenceDataSet: - def test_save_and_load(self, biosequence_data_set, dummy_data): - """Test saving and reloading the data set.""" - biosequence_data_set.save(dummy_data) - reloaded = biosequence_data_set.load() - assert dummy_data[0].id, reloaded[0].id - assert dummy_data[0].seq, reloaded[0].seq - assert len(dummy_data) == len(reloaded) - assert biosequence_data_set._fs_open_args_load == {"mode": "r"} - assert biosequence_data_set._fs_open_args_save == {"mode": "w"} - - def test_exists(self, biosequence_data_set, dummy_data): - """Test `exists` method invocation for both existing and - nonexistent data set.""" - assert not biosequence_data_set.exists() - biosequence_data_set.save(dummy_data) - assert biosequence_data_set.exists() - - def test_load_save_args_propagation(self, biosequence_data_set): - """Test overriding the default load arguments.""" - for key, value in LOAD_ARGS.items(): - assert biosequence_data_set._load_args[key] == value - - for key, value in SAVE_ARGS.items(): - assert biosequence_data_set._save_args[key] == value - - @pytest.mark.parametrize( - "fs_args", - [{"open_args_load": {"mode": "rb", 
"compression": "gzip"}}], - indirect=True, - ) - def test_open_extra_args(self, biosequence_data_set, fs_args): - assert biosequence_data_set._fs_open_args_load == fs_args["open_args_load"] - assert biosequence_data_set._fs_open_args_save == { - "mode": "w" - } # default unchanged - - def test_load_missing_file(self, biosequence_data_set): - """Check the error when trying to load missing file.""" - pattern = r"Failed while loading data from data set BioSequenceDataSet\(.*\)" - with pytest.raises(DatasetError, match=pattern): - biosequence_data_set.load() - - @pytest.mark.parametrize( - "filepath,instance_type", - [ - ("s3://bucket/file.fasta", S3FileSystem), - ("file:///tmp/test.fasta", LocalFileSystem), - ("/tmp/test.fasta", LocalFileSystem), - ("gcs://bucket/file.fasta", GCSFileSystem), - ("https://example.com/file.fasta", HTTPFileSystem), - ], - ) - def test_protocol_usage(self, filepath, instance_type): - data_set = BioSequenceDataSet(filepath=filepath) - assert isinstance(data_set._fs, instance_type) - - path = filepath.split(PROTOCOL_DELIMITER, 1)[-1] - - assert str(data_set._filepath) == path - assert isinstance(data_set._filepath, PurePosixPath) - - def test_catalog_release(self, mocker): - fs_mock = mocker.patch("fsspec.filesystem").return_value - filepath = "test.fasta" - data_set = BioSequenceDataSet(filepath=filepath) - data_set.release() - fs_mock.invalidate_cache.assert_called_once_with(filepath) diff --git a/tests/extras/datasets/conftest.py b/tests/extras/datasets/conftest.py deleted file mode 100644 index b9fddb3f88..0000000000 --- a/tests/extras/datasets/conftest.py +++ /dev/null @@ -1,35 +0,0 @@ -""" -This file contains the fixtures that are reusable by any tests within -this directory. You don't need to import the fixtures as pytest will -discover them automatically. 
More info here: -https://docs.pytest.org/en/latest/fixture.html -""" - -from pytest import fixture - -from kedro.io.core import generate_timestamp - - -@fixture(params=[None]) -def load_version(request): - return request.param - - -@fixture(params=[None]) -def save_version(request): - return request.param or generate_timestamp() - - -@fixture(params=[None]) -def load_args(request): - return request.param - - -@fixture(params=[None]) -def save_args(request): - return request.param - - -@fixture(params=[None]) -def fs_args(request): - return request.param diff --git a/tests/extras/datasets/dask/__init__.py b/tests/extras/datasets/dask/__init__.py deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/tests/extras/datasets/dask/test_parquet_dataset.py b/tests/extras/datasets/dask/test_parquet_dataset.py deleted file mode 100644 index 597d8c40a4..0000000000 --- a/tests/extras/datasets/dask/test_parquet_dataset.py +++ /dev/null @@ -1,223 +0,0 @@ -import boto3 -import dask.dataframe as dd -import pandas as pd -import pyarrow as pa -import pyarrow.parquet as pq -import pytest -from moto import mock_s3 -from pandas.util.testing import assert_frame_equal -from s3fs import S3FileSystem - -from kedro.extras.datasets.dask import ParquetDataSet -from kedro.io import DatasetError - -FILE_NAME = "test.parquet" -BUCKET_NAME = "test_bucket" -AWS_CREDENTIALS = {"key": "FAKE_ACCESS_KEY", "secret": "FAKE_SECRET_KEY"} - -# Pathlib cannot be used since it strips out the second slash from "s3://" -S3_PATH = f"s3://{BUCKET_NAME}/{FILE_NAME}" - - -@pytest.fixture -def mocked_s3_bucket(): - """Create a bucket for testing using moto.""" - with mock_s3(): - conn = boto3.client( - "s3", - aws_access_key_id="fake_access_key", - aws_secret_access_key="fake_secret_key", - ) - conn.create_bucket(Bucket=BUCKET_NAME) - yield conn - - -@pytest.fixture -def dummy_dd_dataframe() -> dd.DataFrame: - df = pd.DataFrame( - {"Name": ["Alex", "Bob", "Clarke", "Dave"], "Age": [31, 12, 65, 29]} - ) - return dd.from_pandas(df, npartitions=2) - - -@pytest.fixture -def mocked_s3_object(tmp_path, mocked_s3_bucket, dummy_dd_dataframe: dd.DataFrame): - """Creates test data and adds it to mocked S3 bucket.""" - pandas_df = dummy_dd_dataframe.compute() - table = pa.Table.from_pandas(pandas_df) - temporary_path = tmp_path / FILE_NAME - pq.write_table(table, str(temporary_path)) - - mocked_s3_bucket.put_object( - Bucket=BUCKET_NAME, Key=FILE_NAME, Body=temporary_path.read_bytes() - ) - return mocked_s3_bucket - - -@pytest.fixture -def s3_data_set(load_args, save_args): - return ParquetDataSet( - filepath=S3_PATH, - credentials=AWS_CREDENTIALS, - load_args=load_args, - save_args=save_args, - ) - - -@pytest.fixture() -def s3fs_cleanup(): - # clear cache so we get a clean slate every time we instantiate a S3FileSystem - yield - S3FileSystem.cachable = False - - -@pytest.mark.usefixtures("s3fs_cleanup") -class TestParquetDataSet: - def test_incorrect_credentials_load(self): - """Test that incorrect credential keys won't instantiate dataset.""" - pattern = r"unexpected keyword argument" - with pytest.raises(DatasetError, match=pattern): - ParquetDataSet( - filepath=S3_PATH, - credentials={ - "client_kwargs": {"access_token": "TOKEN", "access_key": "KEY"} - }, - ).load().compute() - - @pytest.mark.parametrize("bad_credentials", [{"key": None, "secret": None}]) - def test_empty_credentials_load(self, bad_credentials): - parquet_data_set = ParquetDataSet(filepath=S3_PATH, credentials=bad_credentials) - pattern = r"Failed while loading data 
from data set ParquetDataSet\(.+\)" - with pytest.raises(DatasetError, match=pattern): - parquet_data_set.load().compute() - - def test_pass_credentials(self, mocker): - """Test that AWS credentials are passed successfully into boto3 - client instantiation on creating S3 connection.""" - client_mock = mocker.patch("botocore.session.Session.create_client") - s3_data_set = ParquetDataSet(filepath=S3_PATH, credentials=AWS_CREDENTIALS) - pattern = r"Failed while loading data from data set ParquetDataSet\(.+\)" - with pytest.raises(DatasetError, match=pattern): - s3_data_set.load().compute() - - assert client_mock.call_count == 1 - args, kwargs = client_mock.call_args_list[0] - assert args == ("s3",) - assert kwargs["aws_access_key_id"] == AWS_CREDENTIALS["key"] - assert kwargs["aws_secret_access_key"] == AWS_CREDENTIALS["secret"] - - @pytest.mark.usefixtures("mocked_s3_bucket") - def test_save_data(self, s3_data_set): - """Test saving the data to S3.""" - pd_data = pd.DataFrame( - {"col1": ["a", "b"], "col2": ["c", "d"], "col3": ["e", "f"]} - ) - dd_data = dd.from_pandas(pd_data, npartitions=2) - s3_data_set.save(dd_data) - loaded_data = s3_data_set.load() - assert_frame_equal(loaded_data.compute(), dd_data.compute()) - - @pytest.mark.usefixtures("mocked_s3_object") - def test_load_data(self, s3_data_set, dummy_dd_dataframe): - """Test loading the data from S3.""" - loaded_data = s3_data_set.load() - assert_frame_equal(loaded_data.compute(), dummy_dd_dataframe.compute()) - - @pytest.mark.usefixtures("mocked_s3_bucket") - def test_exists(self, s3_data_set, dummy_dd_dataframe): - """Test `exists` method invocation for both existing and - nonexistent data set.""" - assert not s3_data_set.exists() - s3_data_set.save(dummy_dd_dataframe) - assert s3_data_set.exists() - - def test_save_load_locally(self, tmp_path, dummy_dd_dataframe): - """Test loading the data locally.""" - file_path = str(tmp_path / "some" / "dir" / FILE_NAME) - data_set = ParquetDataSet(filepath=file_path) - - assert not data_set.exists() - data_set.save(dummy_dd_dataframe) - assert data_set.exists() - loaded_data = data_set.load() - dummy_dd_dataframe.compute().equals(loaded_data.compute()) - - @pytest.mark.parametrize( - "load_args", [{"k1": "v1", "index": "value"}], indirect=True - ) - def test_load_extra_params(self, s3_data_set, load_args): - """Test overriding the default load arguments.""" - for key, value in load_args.items(): - assert s3_data_set._load_args[key] == value - - @pytest.mark.parametrize( - "save_args", [{"k1": "v1", "index": "value"}], indirect=True - ) - def test_save_extra_params(self, s3_data_set, save_args): - """Test overriding the default save arguments.""" - s3_data_set._process_schema() - assert s3_data_set._save_args.get("schema") is None - - for key, value in save_args.items(): - assert s3_data_set._save_args[key] == value - - for key, value in s3_data_set.DEFAULT_SAVE_ARGS.items(): - assert s3_data_set._save_args[key] == value - - @pytest.mark.parametrize( - "save_args", - [{"schema": {"col1": "[[int64]]", "col2": "string"}}], - indirect=True, - ) - def test_save_extra_params_schema_dict(self, s3_data_set, save_args): - """Test setting the schema as dictionary of pyarrow column types - in save arguments.""" - - for key, value in save_args["schema"].items(): - assert s3_data_set._save_args["schema"][key] == value - - s3_data_set._process_schema() - - for field in s3_data_set._save_args["schema"].values(): - assert isinstance(field, pa.DataType) - - @pytest.mark.parametrize( - "save_args", - [ - { - 
"schema": { - "col1": "[[int64]]", - "col2": "string", - "col3": float, - "col4": pa.int64(), - } - } - ], - indirect=True, - ) - def test_save_extra_params_schema_dict_mixed_types(self, s3_data_set, save_args): - """Test setting the schema as dictionary of mixed value types - in save arguments.""" - - for key, value in save_args["schema"].items(): - assert s3_data_set._save_args["schema"][key] == value - - s3_data_set._process_schema() - - for field in s3_data_set._save_args["schema"].values(): - assert isinstance(field, pa.DataType) - - @pytest.mark.parametrize( - "save_args", - [{"schema": "c1:[int64],c2:int64"}], - indirect=True, - ) - def test_save_extra_params_schema_str_schema_fields(self, s3_data_set, save_args): - """Test setting the schema as string pyarrow schema (list of fields) - in save arguments.""" - - assert s3_data_set._save_args["schema"] == save_args["schema"] - - s3_data_set._process_schema() - - assert isinstance(s3_data_set._save_args["schema"], pa.Schema) diff --git a/tests/extras/datasets/email/__init__.py b/tests/extras/datasets/email/__init__.py deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/tests/extras/datasets/email/test_message_dataset.py b/tests/extras/datasets/email/test_message_dataset.py deleted file mode 100644 index 9eab39be4d..0000000000 --- a/tests/extras/datasets/email/test_message_dataset.py +++ /dev/null @@ -1,226 +0,0 @@ -from email.message import EmailMessage -from email.policy import default -from pathlib import Path, PurePosixPath - -import pytest -from fsspec.implementations.http import HTTPFileSystem -from fsspec.implementations.local import LocalFileSystem -from gcsfs import GCSFileSystem -from s3fs.core import S3FileSystem - -from kedro.extras.datasets.email import EmailMessageDataSet -from kedro.io import DatasetError -from kedro.io.core import PROTOCOL_DELIMITER, Version - - -@pytest.fixture -def filepath_message(tmp_path): - return (tmp_path / "test").as_posix() - - -@pytest.fixture -def message_data_set(filepath_message, load_args, save_args, fs_args): - return EmailMessageDataSet( - filepath=filepath_message, - load_args=load_args, - save_args=save_args, - fs_args=fs_args, - ) - - -@pytest.fixture -def versioned_message_data_set(filepath_message, load_version, save_version): - return EmailMessageDataSet( - filepath=filepath_message, version=Version(load_version, save_version) - ) - - -@pytest.fixture -def dummy_msg(): - string_to_write = "what would you do if you were invisible for one day????" 
- - # Create a text/plain message - msg = EmailMessage() - msg.set_content(string_to_write) - msg["Subject"] = "invisibility" - msg["From"] = '"sin studly17"' - msg["To"] = '"strong bad"' - - return msg - - -class TestEmailMessageDataSet: - def test_save_and_load(self, message_data_set, dummy_msg): - """Test saving and reloading the data set.""" - message_data_set.save(dummy_msg) - reloaded = message_data_set.load() - assert dummy_msg.__dict__ == reloaded.__dict__ - assert message_data_set._fs_open_args_load == {"mode": "r"} - assert message_data_set._fs_open_args_save == {"mode": "w"} - - def test_exists(self, message_data_set, dummy_msg): - """Test `exists` method invocation for both existing and - nonexistent data set.""" - assert not message_data_set.exists() - message_data_set.save(dummy_msg) - assert message_data_set.exists() - - @pytest.mark.parametrize( - "load_args", [{"k1": "v1", "index": "value"}], indirect=True - ) - def test_load_extra_params(self, message_data_set, load_args): - """Test overriding the default load arguments.""" - for key, value in load_args.items(): - assert message_data_set._load_args[key] == value - - @pytest.mark.parametrize( - "save_args", [{"k1": "v1", "index": "value"}], indirect=True - ) - def test_save_extra_params(self, message_data_set, save_args): - """Test overriding the default save arguments.""" - for key, value in save_args.items(): - assert message_data_set._save_args[key] == value - - @pytest.mark.parametrize( - "fs_args", - [{"open_args_load": {"mode": "rb", "compression": "gzip"}}], - indirect=True, - ) - def test_open_extra_args(self, message_data_set, fs_args): - assert message_data_set._fs_open_args_load == fs_args["open_args_load"] - assert message_data_set._fs_open_args_save == {"mode": "w"} # default unchanged - - def test_load_missing_file(self, message_data_set): - """Check the error when trying to load missing file.""" - pattern = r"Failed while loading data from data set EmailMessageDataSet\(.*\)" - with pytest.raises(DatasetError, match=pattern): - message_data_set.load() - - @pytest.mark.parametrize( - "filepath,instance_type", - [ - ("s3://bucket/file", S3FileSystem), - ("file:///tmp/test", LocalFileSystem), - ("/tmp/test", LocalFileSystem), - ("gcs://bucket/file", GCSFileSystem), - ("https://example.com/file", HTTPFileSystem), - ], - ) - def test_protocol_usage(self, filepath, instance_type): - data_set = EmailMessageDataSet(filepath=filepath) - assert isinstance(data_set._fs, instance_type) - - path = filepath.split(PROTOCOL_DELIMITER, 1)[-1] - - assert str(data_set._filepath) == path - assert isinstance(data_set._filepath, PurePosixPath) - - def test_catalog_release(self, mocker): - fs_mock = mocker.patch("fsspec.filesystem").return_value - filepath = "test" - data_set = EmailMessageDataSet(filepath=filepath) - assert data_set._version_cache.currsize == 0 # no cache if unversioned - data_set.release() - fs_mock.invalidate_cache.assert_called_once_with(filepath) - assert data_set._version_cache.currsize == 0 - - -class TestEmailMessageDataSetVersioned: - def test_version_str_repr(self, load_version, save_version): - """Test that version is in string representation of the class instance - when applicable.""" - filepath = "test" - ds = EmailMessageDataSet(filepath=filepath) - ds_versioned = EmailMessageDataSet( - filepath=filepath, version=Version(load_version, save_version) - ) - assert filepath in str(ds) - assert "version" not in str(ds) - - assert filepath in str(ds_versioned) - ver_str = 
f"version=Version(load={load_version}, save='{save_version}')" - assert ver_str in str(ds_versioned) - assert "EmailMessageDataSet" in str(ds_versioned) - assert "EmailMessageDataSet" in str(ds) - assert "protocol" in str(ds_versioned) - assert "protocol" in str(ds) - # Default parser_args - assert f"parser_args={{'policy': {default}}}" in str(ds) - assert f"parser_args={{'policy': {default}}}" in str(ds_versioned) - - def test_save_and_load(self, versioned_message_data_set, dummy_msg): - """Test that saved and reloaded data matches the original one for - the versioned data set.""" - versioned_message_data_set.save(dummy_msg) - reloaded = versioned_message_data_set.load() - assert dummy_msg.__dict__ == reloaded.__dict__ - - def test_no_versions(self, versioned_message_data_set): - """Check the error if no versions are available for load.""" - pattern = r"Did not find any versions for EmailMessageDataSet\(.+\)" - with pytest.raises(DatasetError, match=pattern): - versioned_message_data_set.load() - - def test_exists(self, versioned_message_data_set, dummy_msg): - """Test `exists` method invocation for versioned data set.""" - assert not versioned_message_data_set.exists() - versioned_message_data_set.save(dummy_msg) - assert versioned_message_data_set.exists() - - def test_prevent_overwrite(self, versioned_message_data_set, dummy_msg): - """Check the error when attempting to override the data set if the - corresponding text file for a given save version already exists.""" - versioned_message_data_set.save(dummy_msg) - pattern = ( - r"Save path \'.+\' for EmailMessageDataSet\(.+\) must " - r"not exist if versioning is enabled\." - ) - with pytest.raises(DatasetError, match=pattern): - versioned_message_data_set.save(dummy_msg) - - @pytest.mark.parametrize( - "load_version", ["2019-01-01T23.59.59.999Z"], indirect=True - ) - @pytest.mark.parametrize( - "save_version", ["2019-01-02T00.00.00.000Z"], indirect=True - ) - def test_save_version_warning( - self, versioned_message_data_set, load_version, save_version, dummy_msg - ): - """Check the warning when saving to the path that differs from - the subsequent load path.""" - pattern = ( - f"Save version '{save_version}' did not match " - f"load version '{load_version}' for " - r"EmailMessageDataSet\(.+\)" - ) - with pytest.warns(UserWarning, match=pattern): - versioned_message_data_set.save(dummy_msg) - - def test_http_filesystem_no_versioning(self): - pattern = "Versioning is not supported for HTTP protocols." 
- - with pytest.raises(DatasetError, match=pattern): - EmailMessageDataSet( - filepath="https://example.com/file", version=Version(None, None) - ) - - def test_versioning_existing_dataset( - self, message_data_set, versioned_message_data_set, dummy_msg - ): - """Check the error when attempting to save a versioned dataset on top of an - already existing (non-versioned) dataset.""" - message_data_set.save(dummy_msg) - assert message_data_set.exists() - assert message_data_set._filepath == versioned_message_data_set._filepath - pattern = ( - f"(?=.*file with the same name already exists in the directory)" - f"(?=.*{versioned_message_data_set._filepath.parent.as_posix()})" - ) - with pytest.raises(DatasetError, match=pattern): - versioned_message_data_set.save(dummy_msg) - - # Remove non-versioned dataset and try again - Path(message_data_set._filepath.as_posix()).unlink() - versioned_message_data_set.save(dummy_msg) - assert versioned_message_data_set.exists() diff --git a/tests/extras/datasets/geojson/__init__.py b/tests/extras/datasets/geojson/__init__.py deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/tests/extras/datasets/geojson/test_geojson_dataset.py b/tests/extras/datasets/geojson/test_geojson_dataset.py deleted file mode 100644 index 5a2669964c..0000000000 --- a/tests/extras/datasets/geojson/test_geojson_dataset.py +++ /dev/null @@ -1,232 +0,0 @@ -from pathlib import Path, PurePosixPath - -import geopandas as gpd -import pytest -from fsspec.implementations.http import HTTPFileSystem -from fsspec.implementations.local import LocalFileSystem -from gcsfs import GCSFileSystem -from pandas.util.testing import assert_frame_equal -from s3fs import S3FileSystem -from shapely.geometry import Point - -from kedro.extras.datasets.geopandas import GeoJSONDataSet -from kedro.io import DatasetError -from kedro.io.core import PROTOCOL_DELIMITER, Version, generate_timestamp - - -@pytest.fixture(params=[None]) -def load_version(request): - return request.param - - -@pytest.fixture(params=[None]) -def save_version(request): - return request.param or generate_timestamp() - - -@pytest.fixture -def filepath(tmp_path): - return (tmp_path / "test.geojson").as_posix() - - -@pytest.fixture(params=[None]) -def load_args(request): - return request.param - - -@pytest.fixture(params=[{"driver": "GeoJSON"}]) -def save_args(request): - return request.param - - -@pytest.fixture -def dummy_dataframe(): - return gpd.GeoDataFrame( - {"col1": [1, 2], "col2": [4, 5], "col3": [5, 6]}, - geometry=[Point(1, 1), Point(2, 2)], - ) - - -@pytest.fixture -def geojson_data_set(filepath, load_args, save_args, fs_args): - return GeoJSONDataSet( - filepath=filepath, load_args=load_args, save_args=save_args, fs_args=fs_args - ) - - -@pytest.fixture -def versioned_geojson_data_set(filepath, load_version, save_version): - return GeoJSONDataSet( - filepath=filepath, version=Version(load_version, save_version) - ) - - -class TestGeoJSONDataSet: - def test_save_and_load(self, geojson_data_set, dummy_dataframe): - """Test that saved and reloaded data matches the original one.""" - geojson_data_set.save(dummy_dataframe) - reloaded_df = geojson_data_set.load() - assert_frame_equal(reloaded_df, dummy_dataframe) - assert geojson_data_set._fs_open_args_load == {} - assert geojson_data_set._fs_open_args_save == {"mode": "wb"} - - @pytest.mark.parametrize("geojson_data_set", [{"index": False}], indirect=True) - def test_load_missing_file(self, geojson_data_set): - """Check the error while trying to load from missing source.""" 
- pattern = r"Failed while loading data from data set GeoJSONDataSet" - with pytest.raises(DatasetError, match=pattern): - geojson_data_set.load() - - def test_exists(self, geojson_data_set, dummy_dataframe): - """Test `exists` method invocation for both cases.""" - assert not geojson_data_set.exists() - geojson_data_set.save(dummy_dataframe) - assert geojson_data_set.exists() - - @pytest.mark.parametrize( - "load_args", [{"crs": "init:4326"}, {"crs": "init:2154", "driver": "GeoJSON"}] - ) - def test_load_extra_params(self, geojson_data_set, load_args): - """Test overriding default save args""" - for k, v in load_args.items(): - assert geojson_data_set._load_args[k] == v - - @pytest.mark.parametrize( - "save_args", [{"driver": "ESRI Shapefile"}, {"driver": "GPKG"}] - ) - def test_save_extra_params(self, geojson_data_set, save_args): - """Test overriding default save args""" - for k, v in save_args.items(): - assert geojson_data_set._save_args[k] == v - - @pytest.mark.parametrize( - "fs_args", - [{"open_args_load": {"mode": "rb", "compression": "gzip"}}], - indirect=True, - ) - def test_open_extra_args(self, geojson_data_set, fs_args): - assert geojson_data_set._fs_open_args_load == fs_args["open_args_load"] - assert geojson_data_set._fs_open_args_save == {"mode": "wb"} - - @pytest.mark.parametrize( - "path,instance_type", - [ - ("s3://bucket/file.geojson", S3FileSystem), - ("/tmp/test.geojson", LocalFileSystem), - ("gcs://bucket/file.geojson", GCSFileSystem), - ("file:///tmp/file.geojson", LocalFileSystem), - ("https://example.com/file.geojson", HTTPFileSystem), - ], - ) - def test_protocol_usage(self, path, instance_type): - geojson_data_set = GeoJSONDataSet(filepath=path) - assert isinstance(geojson_data_set._fs, instance_type) - - path = path.split(PROTOCOL_DELIMITER, 1)[-1] - - assert str(geojson_data_set._filepath) == path - assert isinstance(geojson_data_set._filepath, PurePosixPath) - - def test_catalog_release(self, mocker): - fs_mock = mocker.patch("fsspec.filesystem").return_value - filepath = "test.geojson" - geojson_data_set = GeoJSONDataSet(filepath=filepath) - geojson_data_set.release() - fs_mock.invalidate_cache.assert_called_once_with(filepath) - - -class TestGeoJSONDataSetVersioned: - def test_version_str_repr(self, load_version, save_version): - """Test that version is in string representation of the class instance - when applicable.""" - filepath = "test.geojson" - ds = GeoJSONDataSet(filepath=filepath) - ds_versioned = GeoJSONDataSet( - filepath=filepath, version=Version(load_version, save_version) - ) - assert filepath in str(ds) - assert "version" not in str(ds) - - assert filepath in str(ds_versioned) - ver_str = f"version=Version(load={load_version}, save='{save_version}')" - assert ver_str in str(ds_versioned) - assert "GeoJSONDataSet" in str(ds_versioned) - assert "GeoJSONDataSet" in str(ds) - assert "protocol" in str(ds_versioned) - assert "protocol" in str(ds) - - def test_save_and_load(self, versioned_geojson_data_set, dummy_dataframe): - """Test that saved and reloaded data matches the original one for - the versioned data set.""" - versioned_geojson_data_set.save(dummy_dataframe) - reloaded_df = versioned_geojson_data_set.load() - assert_frame_equal(reloaded_df, dummy_dataframe) - - def test_no_versions(self, versioned_geojson_data_set): - """Check the error if no versions are available for load.""" - pattern = r"Did not find any versions for GeoJSONDataSet\(.+\)" - with pytest.raises(DatasetError, match=pattern): - versioned_geojson_data_set.load() - - def 
test_exists(self, versioned_geojson_data_set, dummy_dataframe): - """Test `exists` method invocation for versioned data set.""" - assert not versioned_geojson_data_set.exists() - versioned_geojson_data_set.save(dummy_dataframe) - assert versioned_geojson_data_set.exists() - - def test_prevent_override(self, versioned_geojson_data_set, dummy_dataframe): - """Check the error when attempt to override the same data set - version.""" - versioned_geojson_data_set.save(dummy_dataframe) - pattern = ( - r"Save path \'.+\' for GeoJSONDataSet\(.+\) must not " - r"exist if versioning is enabled" - ) - with pytest.raises(DatasetError, match=pattern): - versioned_geojson_data_set.save(dummy_dataframe) - - @pytest.mark.parametrize( - "load_version", ["2019-01-01T23.59.59.999Z"], indirect=True - ) - @pytest.mark.parametrize( - "save_version", ["2019-01-02T00.00.00.000Z"], indirect=True - ) - def test_save_version_warning( - self, versioned_geojson_data_set, load_version, save_version, dummy_dataframe - ): - """Check the warning when saving to the path that differs from - the subsequent load path.""" - pattern = ( - rf"Save version '{save_version}' did not match load version " - rf"'{load_version}' for GeoJSONDataSet\(.+\)" - ) - with pytest.warns(UserWarning, match=pattern): - versioned_geojson_data_set.save(dummy_dataframe) - - def test_http_filesystem_no_versioning(self): - pattern = "Versioning is not supported for HTTP protocols." - - with pytest.raises(DatasetError, match=pattern): - GeoJSONDataSet( - filepath="https://example/file.geojson", version=Version(None, None) - ) - - def test_versioning_existing_dataset( - self, geojson_data_set, versioned_geojson_data_set, dummy_dataframe - ): - """Check the error when attempting to save a versioned dataset on top of an - already existing (non-versioned) dataset.""" - geojson_data_set.save(dummy_dataframe) - assert geojson_data_set.exists() - assert geojson_data_set._filepath == versioned_geojson_data_set._filepath - pattern = ( - f"(?=.*file with the same name already exists in the directory)" - f"(?=.*{versioned_geojson_data_set._filepath.parent.as_posix()})" - ) - with pytest.raises(DatasetError, match=pattern): - versioned_geojson_data_set.save(dummy_dataframe) - - # Remove non-versioned dataset and try again - Path(geojson_data_set._filepath.as_posix()).unlink() - versioned_geojson_data_set.save(dummy_dataframe) - assert versioned_geojson_data_set.exists() diff --git a/tests/extras/datasets/holoviews/__init__.py b/tests/extras/datasets/holoviews/__init__.py deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/tests/extras/datasets/holoviews/test_holoviews_writer.py b/tests/extras/datasets/holoviews/test_holoviews_writer.py deleted file mode 100644 index 24fb7f6c0f..0000000000 --- a/tests/extras/datasets/holoviews/test_holoviews_writer.py +++ /dev/null @@ -1,220 +0,0 @@ -import sys -from pathlib import Path, PurePosixPath - -import holoviews as hv -import pytest -from adlfs import AzureBlobFileSystem -from fsspec.implementations.http import HTTPFileSystem -from fsspec.implementations.local import LocalFileSystem -from gcsfs import GCSFileSystem -from s3fs.core import S3FileSystem - -from kedro.extras.datasets.holoviews import HoloviewsWriter -from kedro.io import DatasetError, Version -from kedro.io.core import PROTOCOL_DELIMITER - - -@pytest.fixture -def filepath_png(tmp_path): - return (tmp_path / "test.png").as_posix() - - -@pytest.fixture(scope="module") -def dummy_hv_object(): - return hv.Curve(range(10)) - - -@pytest.fixture -def 
hv_writer(filepath_png, save_args, fs_args): - return HoloviewsWriter(filepath_png, save_args=save_args, fs_args=fs_args) - - -@pytest.fixture -def versioned_hv_writer(filepath_png, load_version, save_version): - return HoloviewsWriter(filepath_png, version=Version(load_version, save_version)) - - -@pytest.mark.skipif( - sys.version_info.minor == 10, - reason="Python 3.10 needs matplotlib>=3.5 which breaks holoviews.", -) -class TestHoloviewsWriter: - def test_save_data(self, tmp_path, dummy_hv_object, hv_writer): - """Test saving Holoviews object.""" - hv_writer.save(dummy_hv_object) - - actual_filepath = Path(hv_writer._filepath.as_posix()) - test_filepath = tmp_path / "locally_saved.png" - hv.save(dummy_hv_object, test_filepath) - - assert actual_filepath.read_bytes() == test_filepath.read_bytes() - assert hv_writer._fs_open_args_save == {"mode": "wb"} - assert hv_writer._save_args == {"fmt": "png"} - - @pytest.mark.parametrize( - "fs_args", - [ - { - "storage_option": "value", - "open_args_save": {"mode": "w", "compression": "gzip"}, - } - ], - ) - def test_open_extra_args(self, tmp_path, fs_args, mocker): - fs_mock = mocker.patch("fsspec.filesystem") - writer = HoloviewsWriter(str(tmp_path), fs_args) - - fs_mock.assert_called_once_with("file", auto_mkdir=True, storage_option="value") - assert writer._fs_open_args_save == fs_args["open_args_save"] - - def test_load_fail(self, hv_writer): - pattern = r"Loading not supported for 'HoloviewsWriter'" - with pytest.raises(DatasetError, match=pattern): - hv_writer.load() - - def test_exists(self, dummy_hv_object, hv_writer): - assert not hv_writer.exists() - hv_writer.save(dummy_hv_object) - assert hv_writer.exists() - - def test_catalog_release(self, mocker): - fs_mock = mocker.patch("fsspec.filesystem").return_value - filepath = "test.png" - data_set = HoloviewsWriter(filepath=filepath) - assert data_set._version_cache.currsize == 0 # no cache if unversioned - data_set.release() - fs_mock.invalidate_cache.assert_called_once_with(filepath) - assert data_set._version_cache.currsize == 0 - - @pytest.mark.parametrize("save_args", [{"k1": "v1", "fmt": "svg"}], indirect=True) - def test_save_extra_params(self, hv_writer, save_args): - """Test overriding the default save arguments.""" - for key, value in save_args.items(): - assert hv_writer._save_args[key] == value - - @pytest.mark.parametrize( - "filepath,instance_type,credentials", - [ - ("s3://bucket/file.png", S3FileSystem, {}), - ("file:///tmp/test.png", LocalFileSystem, {}), - ("/tmp/test.png", LocalFileSystem, {}), - ("gcs://bucket/file.png", GCSFileSystem, {}), - ("https://example.com/file.png", HTTPFileSystem, {}), - ( - "abfs://bucket/file.png", - AzureBlobFileSystem, - {"account_name": "test", "account_key": "test"}, - ), - ], - ) - def test_protocol_usage(self, filepath, instance_type, credentials): - data_set = HoloviewsWriter(filepath=filepath, credentials=credentials) - assert isinstance(data_set._fs, instance_type) - - path = filepath.split(PROTOCOL_DELIMITER, 1)[-1] - - assert str(data_set._filepath) == path - assert isinstance(data_set._filepath, PurePosixPath) - - -@pytest.mark.skipif( - sys.version_info.minor == 10, - reason="Python 3.10 needs matplotlib>=3.5 which breaks holoviews.", -) -class TestHoloviewsWriterVersioned: - def test_version_str_repr(self, hv_writer, versioned_hv_writer): - """Test that version is in string representation of the class instance - when applicable.""" - - assert str(hv_writer._filepath) in str(hv_writer) - assert "version=" not in str(hv_writer) 
- assert "protocol" in str(hv_writer) - assert "save_args" in str(hv_writer) - - assert str(versioned_hv_writer._filepath) in str(versioned_hv_writer) - ver_str = f"version={versioned_hv_writer._version}" - assert ver_str in str(versioned_hv_writer) - assert "protocol" in str(versioned_hv_writer) - assert "save_args" in str(versioned_hv_writer) - - def test_prevent_overwrite(self, dummy_hv_object, versioned_hv_writer): - """Check the error when attempting to override the data set if the - corresponding file for a given save version already exists.""" - versioned_hv_writer.save(dummy_hv_object) - pattern = ( - r"Save path \'.+\' for HoloviewsWriter\(.+\) must " - r"not exist if versioning is enabled\." - ) - with pytest.raises(DatasetError, match=pattern): - versioned_hv_writer.save(dummy_hv_object) - - @pytest.mark.parametrize( - "load_version", ["2019-01-01T23.59.59.999Z"], indirect=True - ) - @pytest.mark.parametrize( - "save_version", ["2019-01-02T00.00.00.000Z"], indirect=True - ) - def test_save_version_warning( - self, load_version, save_version, dummy_hv_object, versioned_hv_writer - ): - """Check the warning when saving to the path that differs from - the subsequent load path.""" - pattern = ( - rf"Save version '{save_version}' did not match load version " - rf"'{load_version}' for HoloviewsWriter\(.+\)" - ) - with pytest.warns(UserWarning, match=pattern): - versioned_hv_writer.save(dummy_hv_object) - - def test_http_filesystem_no_versioning(self): - pattern = "Versioning is not supported for HTTP protocols." - - with pytest.raises(DatasetError, match=pattern): - HoloviewsWriter( - filepath="https://example.com/file.png", version=Version(None, None) - ) - - def test_load_not_supported(self, versioned_hv_writer): - """Check the error if no versions are available for load.""" - pattern = ( - rf"Loading not supported for '{versioned_hv_writer.__class__.__name__}'" - ) - with pytest.raises(DatasetError, match=pattern): - versioned_hv_writer.load() - - def test_exists(self, versioned_hv_writer, dummy_hv_object): - """Test `exists` method invocation for versioned data set.""" - assert not versioned_hv_writer.exists() - versioned_hv_writer.save(dummy_hv_object) - assert versioned_hv_writer.exists() - - def test_save_data(self, versioned_hv_writer, dummy_hv_object, tmp_path): - """Test saving Holoviews object with enabled versioning.""" - versioned_hv_writer.save(dummy_hv_object) - - test_filepath = tmp_path / "test_image.png" - actual_filepath = Path(versioned_hv_writer._get_load_path().as_posix()) - - hv.save(dummy_hv_object, test_filepath) - - assert actual_filepath.read_bytes() == test_filepath.read_bytes() - - def test_versioning_existing_dataset( - self, hv_writer, versioned_hv_writer, dummy_hv_object - ): - """Check the error when attempting to save a versioned dataset on top of an - already existing (non-versioned) dataset.""" - hv_writer.save(dummy_hv_object) - assert hv_writer.exists() - assert hv_writer._filepath == versioned_hv_writer._filepath - pattern = ( - f"(?=.*file with the same name already exists in the directory)" - f"(?=.*{versioned_hv_writer._filepath.parent.as_posix()})" - ) - with pytest.raises(DatasetError, match=pattern): - versioned_hv_writer.save(dummy_hv_object) - - # Remove non-versioned dataset and try again - Path(hv_writer._filepath.as_posix()).unlink() - versioned_hv_writer.save(dummy_hv_object) - assert versioned_hv_writer.exists() diff --git a/tests/extras/datasets/json/__init__.py b/tests/extras/datasets/json/__init__.py deleted file mode 100644 index 
e69de29bb2..0000000000 diff --git a/tests/extras/datasets/json/test_json_dataset.py b/tests/extras/datasets/json/test_json_dataset.py deleted file mode 100644 index 531fd007b7..0000000000 --- a/tests/extras/datasets/json/test_json_dataset.py +++ /dev/null @@ -1,200 +0,0 @@ -from pathlib import Path, PurePosixPath - -import pytest -from fsspec.implementations.http import HTTPFileSystem -from fsspec.implementations.local import LocalFileSystem -from gcsfs import GCSFileSystem -from s3fs.core import S3FileSystem - -from kedro.extras.datasets.json import JSONDataSet -from kedro.io import DatasetError -from kedro.io.core import PROTOCOL_DELIMITER, Version - - -@pytest.fixture -def filepath_json(tmp_path): - return (tmp_path / "test.json").as_posix() - - -@pytest.fixture -def json_data_set(filepath_json, save_args, fs_args): - return JSONDataSet(filepath=filepath_json, save_args=save_args, fs_args=fs_args) - - -@pytest.fixture -def versioned_json_data_set(filepath_json, load_version, save_version): - return JSONDataSet( - filepath=filepath_json, version=Version(load_version, save_version) - ) - - -@pytest.fixture -def dummy_data(): - return {"col1": 1, "col2": 2, "col3": 3} - - -class TestJSONDataSet: - def test_save_and_load(self, json_data_set, dummy_data): - """Test saving and reloading the data set.""" - json_data_set.save(dummy_data) - reloaded = json_data_set.load() - assert dummy_data == reloaded - assert json_data_set._fs_open_args_load == {} - assert json_data_set._fs_open_args_save == {"mode": "w"} - - def test_exists(self, json_data_set, dummy_data): - """Test `exists` method invocation for both existing and - nonexistent data set.""" - assert not json_data_set.exists() - json_data_set.save(dummy_data) - assert json_data_set.exists() - - @pytest.mark.parametrize( - "save_args", [{"k1": "v1", "index": "value"}], indirect=True - ) - def test_save_extra_params(self, json_data_set, save_args): - """Test overriding the default save arguments.""" - for key, value in save_args.items(): - assert json_data_set._save_args[key] == value - - @pytest.mark.parametrize( - "fs_args", - [{"open_args_load": {"mode": "rb", "compression": "gzip"}}], - indirect=True, - ) - def test_open_extra_args(self, json_data_set, fs_args): - assert json_data_set._fs_open_args_load == fs_args["open_args_load"] - assert json_data_set._fs_open_args_save == {"mode": "w"} # default unchanged - - def test_load_missing_file(self, json_data_set): - """Check the error when trying to load missing file.""" - pattern = r"Failed while loading data from data set JSONDataSet\(.*\)" - with pytest.raises(DatasetError, match=pattern): - json_data_set.load() - - @pytest.mark.parametrize( - "filepath,instance_type", - [ - ("s3://bucket/file.json", S3FileSystem), - ("file:///tmp/test.json", LocalFileSystem), - ("/tmp/test.json", LocalFileSystem), - ("gcs://bucket/file.json", GCSFileSystem), - ("https://example.com/file.json", HTTPFileSystem), - ], - ) - def test_protocol_usage(self, filepath, instance_type): - data_set = JSONDataSet(filepath=filepath) - assert isinstance(data_set._fs, instance_type) - - path = filepath.split(PROTOCOL_DELIMITER, 1)[-1] - - assert str(data_set._filepath) == path - assert isinstance(data_set._filepath, PurePosixPath) - - def test_catalog_release(self, mocker): - fs_mock = mocker.patch("fsspec.filesystem").return_value - filepath = "test.json" - data_set = JSONDataSet(filepath=filepath) - data_set.release() - fs_mock.invalidate_cache.assert_called_once_with(filepath) - - -class TestJSONDataSetVersioned: - 
def test_version_str_repr(self, load_version, save_version): - """Test that version is in string representation of the class instance - when applicable.""" - filepath = "test.json" - ds = JSONDataSet(filepath=filepath) - ds_versioned = JSONDataSet( - filepath=filepath, version=Version(load_version, save_version) - ) - assert filepath in str(ds) - assert "version" not in str(ds) - - assert filepath in str(ds_versioned) - ver_str = f"version=Version(load={load_version}, save='{save_version}')" - assert ver_str in str(ds_versioned) - assert "JSONDataSet" in str(ds_versioned) - assert "JSONDataSet" in str(ds) - assert "protocol" in str(ds_versioned) - assert "protocol" in str(ds) - # Default save_args - assert "save_args={'indent': 2}" in str(ds) - assert "save_args={'indent': 2}" in str(ds_versioned) - - def test_save_and_load(self, versioned_json_data_set, dummy_data): - """Test that saved and reloaded data matches the original one for - the versioned data set.""" - versioned_json_data_set.save(dummy_data) - reloaded = versioned_json_data_set.load() - assert dummy_data == reloaded - - def test_no_versions(self, versioned_json_data_set): - """Check the error if no versions are available for load.""" - pattern = r"Did not find any versions for JSONDataSet\(.+\)" - with pytest.raises(DatasetError, match=pattern): - versioned_json_data_set.load() - - def test_exists(self, versioned_json_data_set, dummy_data): - """Test `exists` method invocation for versioned data set.""" - assert not versioned_json_data_set.exists() - versioned_json_data_set.save(dummy_data) - assert versioned_json_data_set.exists() - - def test_prevent_overwrite(self, versioned_json_data_set, dummy_data): - """Check the error when attempting to override the data set if the - corresponding json file for a given save version already exists.""" - versioned_json_data_set.save(dummy_data) - pattern = ( - r"Save path \'.+\' for JSONDataSet\(.+\) must " - r"not exist if versioning is enabled\." - ) - with pytest.raises(DatasetError, match=pattern): - versioned_json_data_set.save(dummy_data) - - @pytest.mark.parametrize( - "load_version", ["2019-01-01T23.59.59.999Z"], indirect=True - ) - @pytest.mark.parametrize( - "save_version", ["2019-01-02T00.00.00.000Z"], indirect=True - ) - def test_save_version_warning( - self, versioned_json_data_set, load_version, save_version, dummy_data - ): - """Check the warning when saving to the path that differs from - the subsequent load path.""" - pattern = ( - f"Save version '{save_version}' did not match " - f"load version '{load_version}' for " - r"JSONDataSet\(.+\)" - ) - with pytest.warns(UserWarning, match=pattern): - versioned_json_data_set.save(dummy_data) - - def test_http_filesystem_no_versioning(self): - pattern = "Versioning is not supported for HTTP protocols." 
- - with pytest.raises(DatasetError, match=pattern): - JSONDataSet( - filepath="https://example.com/file.json", version=Version(None, None) - ) - - def test_versioning_existing_dataset( - self, json_data_set, versioned_json_data_set, dummy_data - ): - """Check the error when attempting to save a versioned dataset on top of an - already existing (non-versioned) dataset.""" - json_data_set.save(dummy_data) - assert json_data_set.exists() - assert json_data_set._filepath == versioned_json_data_set._filepath - pattern = ( - f"(?=.*file with the same name already exists in the directory)" - f"(?=.*{versioned_json_data_set._filepath.parent.as_posix()})" - ) - with pytest.raises(DatasetError, match=pattern): - versioned_json_data_set.save(dummy_data) - - # Remove non-versioned dataset and try again - Path(json_data_set._filepath.as_posix()).unlink() - versioned_json_data_set.save(dummy_data) - assert versioned_json_data_set.exists() diff --git a/tests/extras/datasets/libsvm/__init__.py b/tests/extras/datasets/libsvm/__init__.py deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/tests/extras/datasets/libsvm/test_svmlight_dataset.py b/tests/extras/datasets/libsvm/test_svmlight_dataset.py deleted file mode 100644 index 52bfba394d..0000000000 --- a/tests/extras/datasets/libsvm/test_svmlight_dataset.py +++ /dev/null @@ -1,214 +0,0 @@ -from pathlib import Path, PurePosixPath - -import numpy as np -import pytest -from fsspec.implementations.http import HTTPFileSystem -from fsspec.implementations.local import LocalFileSystem -from gcsfs import GCSFileSystem -from s3fs.core import S3FileSystem - -from kedro.extras.datasets.svmlight import SVMLightDataSet -from kedro.io import DatasetError -from kedro.io.core import PROTOCOL_DELIMITER, Version - - -@pytest.fixture -def filepath_svm(tmp_path): - return (tmp_path / "test.svm").as_posix() - - -@pytest.fixture -def svm_data_set(filepath_svm, save_args, load_args, fs_args): - return SVMLightDataSet( - filepath=filepath_svm, save_args=save_args, load_args=load_args, fs_args=fs_args - ) - - -@pytest.fixture -def versioned_svm_data_set(filepath_svm, load_version, save_version): - return SVMLightDataSet( - filepath=filepath_svm, version=Version(load_version, save_version) - ) - - -@pytest.fixture -def dummy_data(): - features = np.array([[1, 2, 10], [1, 0.4, 3.2], [0, 0, 0]]) - label = np.array([1, 0, 3]) - return features, label - - -class TestSVMLightDataSet: - def test_save_and_load(self, svm_data_set, dummy_data): - """Test saving and reloading the data set.""" - svm_data_set.save(dummy_data) - reloaded_features, reloaded_label = svm_data_set.load() - original_features, original_label = dummy_data - assert (original_features == reloaded_features).all() - assert (original_label == reloaded_label).all() - assert svm_data_set._fs_open_args_load == {"mode": "rb"} - assert svm_data_set._fs_open_args_save == {"mode": "wb"} - - def test_exists(self, svm_data_set, dummy_data): - """Test `exists` method invocation for both existing and - nonexistent data set.""" - assert not svm_data_set.exists() - svm_data_set.save(dummy_data) - assert svm_data_set.exists() - - @pytest.mark.parametrize( - "save_args", [{"zero_based": False, "comment": "comment"}], indirect=True - ) - def test_save_extra_save_args(self, svm_data_set, save_args): - """Test overriding the default save arguments.""" - for key, value in save_args.items(): - assert svm_data_set._save_args[key] == value - - @pytest.mark.parametrize( - "load_args", [{"zero_based": False, "n_features": 3}], 
indirect=True - ) - def test_save_extra_load_args(self, svm_data_set, load_args): - """Test overriding the default load arguments.""" - for key, value in load_args.items(): - assert svm_data_set._load_args[key] == value - - @pytest.mark.parametrize( - "fs_args", - [{"open_args_load": {"mode": "rb", "compression": "gzip"}}], - indirect=True, - ) - def test_open_extra_args(self, svm_data_set, fs_args): - assert svm_data_set._fs_open_args_load == fs_args["open_args_load"] - assert svm_data_set._fs_open_args_save == {"mode": "wb"} # default unchanged - - def test_load_missing_file(self, svm_data_set): - """Check the error when trying to load missing file.""" - pattern = r"Failed while loading data from data set SVMLightDataSet\(.*\)" - with pytest.raises(DatasetError, match=pattern): - svm_data_set.load() - - @pytest.mark.parametrize( - "filepath,instance_type", - [ - ("s3://bucket/file.svm", S3FileSystem), - ("file:///tmp/test.svm", LocalFileSystem), - ("/tmp/test.svm", LocalFileSystem), - ("gcs://bucket/file.svm", GCSFileSystem), - ("https://example.com/file.svm", HTTPFileSystem), - ], - ) - def test_protocol_usage(self, filepath, instance_type): - data_set = SVMLightDataSet(filepath=filepath) - assert isinstance(data_set._fs, instance_type) - - path = filepath.split(PROTOCOL_DELIMITER, 1)[-1] - - assert str(data_set._filepath) == path - assert isinstance(data_set._filepath, PurePosixPath) - - def test_catalog_release(self, mocker): - fs_mock = mocker.patch("fsspec.filesystem").return_value - filepath = "test.svm" - data_set = SVMLightDataSet(filepath=filepath) - data_set.release() - fs_mock.invalidate_cache.assert_called_once_with(filepath) - - -class TestSVMLightDataSetVersioned: - def test_version_str_repr(self, load_version, save_version): - """Test that version is in string representation of the class instance - when applicable.""" - filepath = "test.svm" - ds = SVMLightDataSet(filepath=filepath) - ds_versioned = SVMLightDataSet( - filepath=filepath, version=Version(load_version, save_version) - ) - assert filepath in str(ds) - assert "version" not in str(ds) - - assert filepath in str(ds_versioned) - ver_str = f"version=Version(load={load_version}, save='{save_version}')" - assert ver_str in str(ds_versioned) - assert "SVMLightDataSet" in str(ds_versioned) - assert "SVMLightDataSet" in str(ds) - assert "protocol" in str(ds_versioned) - assert "protocol" in str(ds) - - def test_save_and_load(self, versioned_svm_data_set, dummy_data): - """Test that saved and reloaded data matches the original one for - the versioned data set.""" - versioned_svm_data_set.save(dummy_data) - reloaded_features, reloaded_label = versioned_svm_data_set.load() - original_features, original_label = dummy_data - assert (original_features == reloaded_features).all() - assert (original_label == reloaded_label).all() - - def test_no_versions(self, versioned_svm_data_set): - """Check the error if no versions are available for load.""" - pattern = r"Did not find any versions for SVMLightDataSet\(.+\)" - with pytest.raises(DatasetError, match=pattern): - versioned_svm_data_set.load() - - def test_exists(self, versioned_svm_data_set, dummy_data): - """Test `exists` method invocation for versioned data set.""" - assert not versioned_svm_data_set.exists() - versioned_svm_data_set.save(dummy_data) - assert versioned_svm_data_set.exists() - - def test_prevent_overwrite(self, versioned_svm_data_set, dummy_data): - """Check the error when attempting to override the data set if the - corresponding json file for a given save 
version already exists.""" - versioned_svm_data_set.save(dummy_data) - pattern = ( - r"Save path \'.+\' for SVMLightDataSet\(.+\) must " - r"not exist if versioning is enabled\." - ) - with pytest.raises(DatasetError, match=pattern): - versioned_svm_data_set.save(dummy_data) - - @pytest.mark.parametrize( - "load_version", ["2019-01-01T23.59.59.999Z"], indirect=True - ) - @pytest.mark.parametrize( - "save_version", ["2019-01-02T00.00.00.000Z"], indirect=True - ) - def test_save_version_warning( - self, versioned_svm_data_set, load_version, save_version, dummy_data - ): - """Check the warning when saving to the path that differs from - the subsequent load path.""" - pattern = ( - f"Save version '{save_version}' did not match " - f"load version '{load_version}' for " - r"SVMLightDataSet\(.+\)" - ) - with pytest.warns(UserWarning, match=pattern): - versioned_svm_data_set.save(dummy_data) - - def test_http_filesystem_no_versioning(self): - pattern = "Versioning is not supported for HTTP protocols." - - with pytest.raises(DatasetError, match=pattern): - SVMLightDataSet( - filepath="https://example.com/file.svm", version=Version(None, None) - ) - - def test_versioning_existing_dataset( - self, svm_data_set, versioned_svm_data_set, dummy_data - ): - """Check the error when attempting to save a versioned dataset on top of an - already existing (non-versioned) dataset.""" - svm_data_set.save(dummy_data) - assert svm_data_set.exists() - assert svm_data_set._filepath == versioned_svm_data_set._filepath - pattern = ( - f"(?=.*file with the same name already exists in the directory)" - f"(?=.*{versioned_svm_data_set._filepath.parent.as_posix()})" - ) - with pytest.raises(DatasetError, match=pattern): - versioned_svm_data_set.save(dummy_data) - - # Remove non-versioned dataset and try again - Path(svm_data_set._filepath.as_posix()).unlink() - versioned_svm_data_set.save(dummy_data) - assert versioned_svm_data_set.exists() diff --git a/tests/extras/datasets/matplotlib/__init__.py b/tests/extras/datasets/matplotlib/__init__.py deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/tests/extras/datasets/matplotlib/test_matplotlib_writer.py b/tests/extras/datasets/matplotlib/test_matplotlib_writer.py deleted file mode 100644 index e6ee5be83b..0000000000 --- a/tests/extras/datasets/matplotlib/test_matplotlib_writer.py +++ /dev/null @@ -1,436 +0,0 @@ -import json -from pathlib import Path - -import boto3 -import matplotlib -import matplotlib.pyplot as plt -import pytest -from moto import mock_s3 -from s3fs import S3FileSystem - -from kedro.extras.datasets.matplotlib import MatplotlibWriter -from kedro.io import DatasetError, Version - -BUCKET_NAME = "test_bucket" -AWS_CREDENTIALS = {"key": "testing", "secret": "testing"} -KEY_PATH = "matplotlib" -COLOUR_LIST = ["blue", "green", "red"] -FULL_PATH = f"s3://{BUCKET_NAME}/{KEY_PATH}" - -matplotlib.use("Agg") # Disable interactive mode - - -@pytest.fixture -def mock_single_plot(): - plt.plot([1, 2, 3], [4, 5, 6]) - plt.close("all") - return plt - - -@pytest.fixture -def mock_list_plot(): - plots_list = [] - colour = "red" - for index in range(5): # pylint: disable=unused-variable - plots_list.append(plt.figure()) - plt.plot([1, 2, 3], [4, 5, 6], color=colour) - plt.close("all") - return plots_list - - -@pytest.fixture -def mock_dict_plot(): - plots_dict = {} - for colour in COLOUR_LIST: - plots_dict[colour] = plt.figure() - plt.plot([1, 2, 3], [4, 5, 6], color=colour) - plt.close("all") - return plots_dict - - -@pytest.fixture -def mocked_s3_bucket(): - 
"""Create a bucket for testing using moto.""" - with mock_s3(): - conn = boto3.client( - "s3", - aws_access_key_id="fake_access_key", - aws_secret_access_key="fake_secret_key", - ) - conn.create_bucket(Bucket=BUCKET_NAME) - yield conn - - -@pytest.fixture -def mocked_encrypted_s3_bucket(): - bucket_policy = { - "Version": "2012-10-17", - "Id": "PutObjPolicy", - "Statement": [ - { - "Sid": "DenyUnEncryptedObjectUploads", - "Effect": "Deny", - "Principal": "*", - "Action": "s3:PutObject", - "Resource": f"arn:aws:s3:::{BUCKET_NAME}/*", - "Condition": {"Null": {"s3:x-amz-server-side-encryption": "aws:kms"}}, - } - ], - } - bucket_policy = json.dumps(bucket_policy) - - with mock_s3(): - conn = boto3.client( - "s3", - aws_access_key_id="fake_access_key", - aws_secret_access_key="fake_secret_key", - ) - conn.create_bucket(Bucket=BUCKET_NAME) - conn.put_bucket_policy(Bucket=BUCKET_NAME, Policy=bucket_policy) - yield conn - - -@pytest.fixture() -def s3fs_cleanup(): - # clear cache for clean mocked s3 bucket each time - yield - S3FileSystem.cachable = False - - -@pytest.fixture(params=[False]) -def overwrite(request): - return request.param - - -@pytest.fixture -def plot_writer( - mocked_s3_bucket, fs_args, save_args, overwrite -): # pylint: disable=unused-argument - return MatplotlibWriter( - filepath=FULL_PATH, - credentials=AWS_CREDENTIALS, - fs_args=fs_args, - save_args=save_args, - overwrite=overwrite, - ) - - -@pytest.fixture -def versioned_plot_writer(tmp_path, load_version, save_version): - filepath = (tmp_path / "matplotlib.png").as_posix() - return MatplotlibWriter( - filepath=filepath, version=Version(load_version, save_version) - ) - - -@pytest.fixture(autouse=True) -def cleanup_plt(): - yield - plt.close("all") - - -class TestMatplotlibWriter: - @pytest.mark.parametrize("save_args", [{"k1": "v1"}], indirect=True) - def test_save_data( - self, tmp_path, mock_single_plot, plot_writer, mocked_s3_bucket, save_args - ): - """Test saving single matplotlib plot to S3.""" - plot_writer.save(mock_single_plot) - - download_path = tmp_path / "downloaded_image.png" - actual_filepath = tmp_path / "locally_saved.png" - - mock_single_plot.savefig(str(actual_filepath)) - - mocked_s3_bucket.download_file(BUCKET_NAME, KEY_PATH, str(download_path)) - - assert actual_filepath.read_bytes() == download_path.read_bytes() - assert plot_writer._fs_open_args_save == {"mode": "wb"} - for key, value in save_args.items(): - assert plot_writer._save_args[key] == value - - def test_list_save(self, tmp_path, mock_list_plot, plot_writer, mocked_s3_bucket): - """Test saving list of plots to S3.""" - - plot_writer.save(mock_list_plot) - - for index in range(5): - download_path = tmp_path / "downloaded_image.png" - actual_filepath = tmp_path / "locally_saved.png" - - mock_list_plot[index].savefig(str(actual_filepath)) - _key_path = f"{KEY_PATH}/{index}.png" - mocked_s3_bucket.download_file(BUCKET_NAME, _key_path, str(download_path)) - - assert actual_filepath.read_bytes() == download_path.read_bytes() - - def test_dict_save(self, tmp_path, mock_dict_plot, plot_writer, mocked_s3_bucket): - """Test saving dictionary of plots to S3.""" - - plot_writer.save(mock_dict_plot) - - for colour in COLOUR_LIST: - - download_path = tmp_path / "downloaded_image.png" - actual_filepath = tmp_path / "locally_saved.png" - - mock_dict_plot[colour].savefig(str(actual_filepath)) - - _key_path = f"{KEY_PATH}/{colour}" - - mocked_s3_bucket.download_file(BUCKET_NAME, _key_path, str(download_path)) - - assert actual_filepath.read_bytes() == 
download_path.read_bytes() - - @pytest.mark.parametrize( - "overwrite,expected_num_plots", [(False, 8), (True, 3)], indirect=["overwrite"] - ) - def test_overwrite( - self, - mock_list_plot, - mock_dict_plot, - plot_writer, - mocked_s3_bucket, - expected_num_plots, - ): - """Test saving dictionary of plots after list of plots to S3.""" - - plot_writer.save(mock_list_plot) - plot_writer.save(mock_dict_plot) - - response = mocked_s3_bucket.list_objects(Bucket=BUCKET_NAME) - saved_plots = {obj["Key"] for obj in response["Contents"]} - - assert {f"{KEY_PATH}/{colour}" for colour in COLOUR_LIST} <= saved_plots - assert len(saved_plots) == expected_num_plots - - def test_fs_args(self, tmp_path, mock_single_plot, mocked_encrypted_s3_bucket): - """Test writing to encrypted bucket.""" - normal_encryped_writer = MatplotlibWriter( - fs_args={"s3_additional_kwargs": {"ServerSideEncryption": "AES256"}}, - filepath=FULL_PATH, - credentials=AWS_CREDENTIALS, - ) - - normal_encryped_writer.save(mock_single_plot) - - download_path = tmp_path / "downloaded_image.png" - actual_filepath = tmp_path / "locally_saved.png" - - mock_single_plot.savefig(str(actual_filepath)) - - mocked_encrypted_s3_bucket.download_file( - BUCKET_NAME, KEY_PATH, str(download_path) - ) - - assert actual_filepath.read_bytes() == download_path.read_bytes() - - @pytest.mark.parametrize( - "fs_args", - [{"open_args_save": {"mode": "w", "compression": "gzip"}}], - indirect=True, - ) - def test_open_extra_args(self, plot_writer, fs_args): - assert plot_writer._fs_open_args_save == fs_args["open_args_save"] - - def test_load_fail(self, plot_writer): - pattern = r"Loading not supported for 'MatplotlibWriter'" - with pytest.raises(DatasetError, match=pattern): - plot_writer.load() - - @pytest.mark.usefixtures("s3fs_cleanup") - def test_exists_single(self, mock_single_plot, plot_writer): - assert not plot_writer.exists() - plot_writer.save(mock_single_plot) - assert plot_writer.exists() - - @pytest.mark.usefixtures("s3fs_cleanup") - def test_exists_multiple(self, mock_dict_plot, plot_writer): - assert not plot_writer.exists() - plot_writer.save(mock_dict_plot) - assert plot_writer.exists() - - def test_release(self, mocker): - fs_mock = mocker.patch("fsspec.filesystem").return_value - data_set = MatplotlibWriter(filepath=FULL_PATH) - data_set.release() - fs_mock.invalidate_cache.assert_called_once_with(f"{BUCKET_NAME}/{KEY_PATH}") - - -class TestMatplotlibWriterVersioned: - def test_version_str_repr(self, load_version, save_version): - """Test that version is in string representation of the class instance - when applicable.""" - filepath = "chart.png" - chart = MatplotlibWriter(filepath=filepath) - chart_versioned = MatplotlibWriter( - filepath=filepath, version=Version(load_version, save_version) - ) - assert filepath in str(chart) - assert "version" not in str(chart) - - assert filepath in str(chart_versioned) - ver_str = f"version=Version(load={load_version}, save='{save_version}')" - assert ver_str in str(chart_versioned) - - def test_prevent_overwrite(self, mock_single_plot, versioned_plot_writer): - """Check the error when attempting to override the data set if the - corresponding matplotlib file for a given save version already exists.""" - versioned_plot_writer.save(mock_single_plot) - pattern = ( - r"Save path \'.+\' for MatplotlibWriter\(.+\) must " - r"not exist if versioning is enabled\." 
- ) - with pytest.raises(DatasetError, match=pattern): - versioned_plot_writer.save(mock_single_plot) - - def test_ineffective_overwrite(self, load_version, save_version): - pattern = ( - "Setting 'overwrite=True' is ineffective if versioning " - "is enabled, since the versioned path must not already " - "exist; overriding flag with 'overwrite=False' instead." - ) - with pytest.warns(UserWarning, match=pattern): - versioned_plot_writer = MatplotlibWriter( - filepath="/tmp/file.txt", - version=Version(load_version, save_version), - overwrite=True, - ) - assert not versioned_plot_writer._overwrite - - @pytest.mark.parametrize( - "load_version", ["2019-01-01T23.59.59.999Z"], indirect=True - ) - @pytest.mark.parametrize( - "save_version", ["2019-01-02T00.00.00.000Z"], indirect=True - ) - def test_save_version_warning( - self, load_version, save_version, mock_single_plot, versioned_plot_writer - ): - """Check the warning when saving to the path that differs from - the subsequent load path.""" - pattern = ( - rf"Save version '{save_version}' did not match load version " - rf"'{load_version}' for MatplotlibWriter\(.+\)" - ) - with pytest.warns(UserWarning, match=pattern): - versioned_plot_writer.save(mock_single_plot) - - def test_http_filesystem_no_versioning(self): - pattern = "Versioning is not supported for HTTP protocols." - - with pytest.raises(DatasetError, match=pattern): - MatplotlibWriter( - filepath="https://example.com/file.png", version=Version(None, None) - ) - - def test_load_not_supported(self, versioned_plot_writer): - """Check the error if no versions are available for load.""" - pattern = ( - rf"Loading not supported for '{versioned_plot_writer.__class__.__name__}'" - ) - with pytest.raises(DatasetError, match=pattern): - versioned_plot_writer.load() - - def test_exists(self, versioned_plot_writer, mock_single_plot): - """Test `exists` method invocation for versioned data set.""" - assert not versioned_plot_writer.exists() - versioned_plot_writer.save(mock_single_plot) - assert versioned_plot_writer.exists() - - def test_exists_multiple(self, versioned_plot_writer, mock_list_plot): - """Test `exists` method invocation for versioned data set.""" - assert not versioned_plot_writer.exists() - versioned_plot_writer.save(mock_list_plot) - assert versioned_plot_writer.exists() - - def test_save_data(self, versioned_plot_writer, mock_single_plot, tmp_path): - """Test saving dictionary of plots with enabled versioning.""" - versioned_plot_writer.save(mock_single_plot) - - test_path = tmp_path / "test_image.png" - actual_filepath = Path(versioned_plot_writer._get_load_path().as_posix()) - - plt.savefig(str(test_path)) - - assert actual_filepath.read_bytes() == test_path.read_bytes() - - def test_list_save(self, tmp_path, mock_list_plot, versioned_plot_writer): - """Test saving list of plots to with enabled versioning.""" - - versioned_plot_writer.save(mock_list_plot) - - for index in range(5): - - test_path = tmp_path / "test_image.png" - versioned_filepath = str(versioned_plot_writer._get_load_path()) - - mock_list_plot[index].savefig(str(test_path)) - actual_filepath = Path(f"{versioned_filepath}/{index}.png") - - assert actual_filepath.read_bytes() == test_path.read_bytes() - - def test_dict_save(self, tmp_path, mock_dict_plot, versioned_plot_writer): - """Test saving dictionary of plots with enabled versioning.""" - - versioned_plot_writer.save(mock_dict_plot) - - for colour in COLOUR_LIST: - test_path = tmp_path / "test_image.png" - versioned_filepath = 
str(versioned_plot_writer._get_load_path())
-
-            mock_dict_plot[colour].savefig(str(test_path))
-            actual_filepath = Path(f"{versioned_filepath}/{colour}")
-
-            assert actual_filepath.read_bytes() == test_path.read_bytes()
-
-    def test_versioning_existing_dataset_single_plot(
-        self, plot_writer, versioned_plot_writer, mock_single_plot
-    ):
-        """Check the error when attempting to save a versioned dataset on top of an
-        already existing (non-versioned) dataset, using a single plot."""
-
-        plot_writer = MatplotlibWriter(
-            filepath=versioned_plot_writer._filepath.as_posix()
-        )
-        plot_writer.save(mock_single_plot)
-        assert plot_writer.exists()
-        pattern = (
-            f"(?=.*file with the same name already exists in the directory)"
-            f"(?=.*{versioned_plot_writer._filepath.parent.as_posix()})"
-        )
-        with pytest.raises(DatasetError, match=pattern):
-            versioned_plot_writer.save(mock_single_plot)
-
-        # Remove non-versioned dataset and try again
-        Path(plot_writer._filepath.as_posix()).unlink()
-        versioned_plot_writer.save(mock_single_plot)
-        assert versioned_plot_writer.exists()
-
-    def test_versioning_existing_dataset_list_plot(
-        self, plot_writer, versioned_plot_writer, mock_list_plot
-    ):
-        """Check the behavior when attempting to save a versioned dataset on top of an
-        already existing (non-versioned) dataset, using a list of plots. Note: because
-        a list of plots saves to a directory, an error is not expected."""
-        plot_writer = MatplotlibWriter(
-            filepath=versioned_plot_writer._filepath.as_posix()
-        )
-        plot_writer.save(mock_list_plot)
-        assert plot_writer.exists()
-        versioned_plot_writer.save(mock_list_plot)
-        assert versioned_plot_writer.exists()
-
-    def test_versioning_existing_dataset_dict_plot(
-        self, plot_writer, versioned_plot_writer, mock_dict_plot
-    ):
-        """Check the behavior when attempting to save a versioned dataset on top of an
-        already existing (non-versioned) dataset, using a dict of plots. Note: because
-        a dict of plots saves to a directory, an error is not expected."""
-        plot_writer = MatplotlibWriter(
-            filepath=versioned_plot_writer._filepath.as_posix()
-        )
-        plot_writer.save(mock_dict_plot)
-        assert plot_writer.exists()
-        versioned_plot_writer.save(mock_dict_plot)
-        assert versioned_plot_writer.exists()
diff --git a/tests/extras/datasets/networkx/__init__.py b/tests/extras/datasets/networkx/__init__.py
deleted file mode 100644
index e69de29bb2..0000000000
diff --git a/tests/extras/datasets/networkx/test_gml_dataset.py b/tests/extras/datasets/networkx/test_gml_dataset.py
deleted file mode 100644
index 88f7b18a77..0000000000
--- a/tests/extras/datasets/networkx/test_gml_dataset.py
+++ /dev/null
@@ -1,188 +0,0 @@
-from pathlib import Path, PurePosixPath
-
-import networkx
-import pytest
-from fsspec.implementations.http import HTTPFileSystem
-from fsspec.implementations.local import LocalFileSystem
-from gcsfs import GCSFileSystem
-from s3fs.core import S3FileSystem
-
-from kedro.extras.datasets.networkx import GMLDataSet
-from kedro.io import DatasetError, Version
-from kedro.io.core import PROTOCOL_DELIMITER
-
-ATTRS = {
-    "source": "from",
-    "target": "to",
-    "name": "fake_id",
-    "key": "fake_key",
-    "link": "fake_link",
-}
-
-
-@pytest.fixture
-def filepath_gml(tmp_path):
-    return (tmp_path / "some_dir" / "test.gml").as_posix()
-
-
-@pytest.fixture
-def gml_data_set(filepath_gml):
-    return GMLDataSet(
-        filepath=filepath_gml,
-        load_args={"destringizer": int},
-        save_args={"stringizer": str},
-    )
-
-
-@pytest.fixture
-def versioned_gml_data_set(filepath_gml, load_version, save_version):
-    return GMLDataSet(
-        filepath=filepath_gml,
-        version=Version(load_version, save_version),
-        load_args={"destringizer": int},
-        save_args={"stringizer": str},
-    )
-
-
-@pytest.fixture()
-def dummy_graph_data():
-    return networkx.complete_graph(3)
-
-
-class TestGMLDataSet:
-    def test_save_and_load(self, gml_data_set, dummy_graph_data):
-        """Test saving and reloading the data set."""
-        gml_data_set.save(dummy_graph_data)
-        reloaded = gml_data_set.load()
-        assert dummy_graph_data.nodes(data=True) == reloaded.nodes(data=True)
-        assert gml_data_set._fs_open_args_load == {"mode": "rb"}
-        assert gml_data_set._fs_open_args_save == {"mode": "wb"}
-
-    def test_load_missing_file(self, gml_data_set):
-        """Check the error when trying to load missing file."""
-        pattern = r"Failed while loading data from data set GMLDataSet\(.*\)"
-        with pytest.raises(DatasetError, match=pattern):
-            assert gml_data_set.load()
-
-    def test_exists(self, gml_data_set, dummy_graph_data):
-        """Test `exists` method invocation."""
-        assert not gml_data_set.exists()
-        gml_data_set.save(dummy_graph_data)
-        assert gml_data_set.exists()
-
-    @pytest.mark.parametrize(
-        "filepath,instance_type",
-        [
-            ("s3://bucket/file.gml", S3FileSystem),
-            ("file:///tmp/test.gml", LocalFileSystem),
-            ("/tmp/test.gml", LocalFileSystem),
-            ("gcs://bucket/file.gml", GCSFileSystem),
-            ("https://example.com/file.gml", HTTPFileSystem),
-        ],
-    )
-    def test_protocol_usage(self, filepath, instance_type):
-        data_set = GMLDataSet(filepath=filepath)
-        assert isinstance(data_set._fs, instance_type)
-
-        path = filepath.split(PROTOCOL_DELIMITER, 1)[-1]
-
-        assert str(data_set._filepath) == path
-        assert isinstance(data_set._filepath, PurePosixPath)
-
-    def test_catalog_release(self, mocker):
-        fs_mock = mocker.patch("fsspec.filesystem").return_value
-        filepath = "test.gml"
-        data_set = GMLDataSet(filepath=filepath)
-        data_set.release()
-        fs_mock.invalidate_cache.assert_called_once_with(filepath)
-
-
-class TestGMLDataSetVersioned:
-    def test_save_and_load(self, versioned_gml_data_set, dummy_graph_data):
-        """Test that saved and reloaded data matches the original one for
-        the versioned data set."""
-        versioned_gml_data_set.save(dummy_graph_data)
-        reloaded = versioned_gml_data_set.load()
-        assert dummy_graph_data.nodes(data=True) == reloaded.nodes(data=True)
-        assert versioned_gml_data_set._fs_open_args_load == {"mode": "rb"}
-        assert versioned_gml_data_set._fs_open_args_save == {"mode": "wb"}
-
-    def test_no_versions(self, versioned_gml_data_set):
-        """Check the error if no versions are available for load."""
-        pattern = r"Did not find any versions for GMLDataSet\(.+\)"
-        with pytest.raises(DatasetError, match=pattern):
-            versioned_gml_data_set.load()
-
-    def test_exists(self, versioned_gml_data_set, dummy_graph_data):
-        """Test `exists` method invocation for versioned data set."""
-        assert not versioned_gml_data_set.exists()
-        versioned_gml_data_set.save(dummy_graph_data)
-        assert versioned_gml_data_set.exists()
-
-    def test_prevent_override(self, versioned_gml_data_set, dummy_graph_data):
-        """Check the error when attempt to override the same data set
-        version."""
-        versioned_gml_data_set.save(dummy_graph_data)
-        pattern = (
-            r"Save path \'.+\' for GMLDataSet\(.+\) must not "
-            r"exist if versioning is enabled"
-        )
-        with pytest.raises(DatasetError, match=pattern):
-            versioned_gml_data_set.save(dummy_graph_data)
-
-    @pytest.mark.parametrize(
-        "load_version", ["2019-01-01T23.59.59.999Z"], indirect=True
-    )
-    @pytest.mark.parametrize(
-        "save_version", ["2019-01-02T00.00.00.000Z"], indirect=True
-    )
-    def test_save_version_warning(
-        self, versioned_gml_data_set, load_version, save_version, dummy_graph_data
-    ):
-        """Check the warning when saving to the path that differs from
-        the subsequent load path."""
-        pattern = (
-            rf"Save version '{save_version}' did not match "
-            rf"load version '{load_version}' for GMLDataSet\(.+\)"
-        )
-        with pytest.warns(UserWarning, match=pattern):
-            versioned_gml_data_set.save(dummy_graph_data)
-
-    def test_version_str_repr(self, load_version, save_version):
-        """Test that version is in string representation of the class instance
-        when applicable."""
-        filepath = "test.gml"
-        ds = GMLDataSet(filepath=filepath)
-        ds_versioned = GMLDataSet(
-            filepath=filepath, version=Version(load_version, save_version)
-        )
-        assert filepath in str(ds)
-        assert "version" not in str(ds)
-
-        assert filepath in str(ds_versioned)
-        ver_str = f"version=Version(load={load_version}, save='{save_version}')"
-        assert ver_str in str(ds_versioned)
-        assert "GMLDataSet" in str(ds_versioned)
-        assert "GMLDataSet" in str(ds)
-        assert "protocol" in str(ds_versioned)
-        assert "protocol" in str(ds)
-
-    def test_versioning_existing_dataset(
-        self, gml_data_set, versioned_gml_data_set, dummy_graph_data
-    ):
-        """Check the error when attempting to save a versioned dataset on top of an
-        already existing (non-versioned) dataset."""
-        gml_data_set.save(dummy_graph_data)
-        assert gml_data_set.exists()
-        assert gml_data_set._filepath == versioned_gml_data_set._filepath
-        pattern = (
-            f"(?=.*file with the same name already exists in the directory)"
-            f"(?=.*{versioned_gml_data_set._filepath.parent.as_posix()})"
-        )
-        with pytest.raises(DatasetError, match=pattern):
-            versioned_gml_data_set.save(dummy_graph_data)
-
-        # Remove non-versioned dataset and try again
-        Path(gml_data_set._filepath.as_posix()).unlink()
-
versioned_gml_data_set.save(dummy_graph_data) - assert versioned_gml_data_set.exists() diff --git a/tests/extras/datasets/networkx/test_graphml_dataset.py b/tests/extras/datasets/networkx/test_graphml_dataset.py deleted file mode 100644 index 1d744a61cb..0000000000 --- a/tests/extras/datasets/networkx/test_graphml_dataset.py +++ /dev/null @@ -1,188 +0,0 @@ -from pathlib import Path, PurePosixPath - -import networkx -import pytest -from fsspec.implementations.http import HTTPFileSystem -from fsspec.implementations.local import LocalFileSystem -from gcsfs import GCSFileSystem -from s3fs.core import S3FileSystem - -from kedro.extras.datasets.networkx import GraphMLDataSet -from kedro.io import DatasetError, Version -from kedro.io.core import PROTOCOL_DELIMITER - -ATTRS = { - "source": "from", - "target": "to", - "name": "fake_id", - "key": "fake_key", - "link": "fake_link", -} - - -@pytest.fixture -def filepath_graphml(tmp_path): - return (tmp_path / "some_dir" / "test.graphml").as_posix() - - -@pytest.fixture -def graphml_data_set(filepath_graphml): - return GraphMLDataSet( - filepath=filepath_graphml, - load_args={"node_type": int}, - save_args={}, - ) - - -@pytest.fixture -def versioned_graphml_data_set(filepath_graphml, load_version, save_version): - return GraphMLDataSet( - filepath=filepath_graphml, - version=Version(load_version, save_version), - load_args={"node_type": int}, - save_args={}, - ) - - -@pytest.fixture() -def dummy_graph_data(): - return networkx.complete_graph(3) - - -class TestGraphMLDataSet: - def test_save_and_load(self, graphml_data_set, dummy_graph_data): - """Test saving and reloading the data set.""" - graphml_data_set.save(dummy_graph_data) - reloaded = graphml_data_set.load() - assert dummy_graph_data.nodes(data=True) == reloaded.nodes(data=True) - assert graphml_data_set._fs_open_args_load == {"mode": "rb"} - assert graphml_data_set._fs_open_args_save == {"mode": "wb"} - - def test_load_missing_file(self, graphml_data_set): - """Check the error when trying to load missing file.""" - pattern = r"Failed while loading data from data set GraphMLDataSet\(.*\)" - with pytest.raises(DatasetError, match=pattern): - assert graphml_data_set.load() - - def test_exists(self, graphml_data_set, dummy_graph_data): - """Test `exists` method invocation.""" - assert not graphml_data_set.exists() - graphml_data_set.save(dummy_graph_data) - assert graphml_data_set.exists() - - @pytest.mark.parametrize( - "filepath,instance_type", - [ - ("s3://bucket/file.graphml", S3FileSystem), - ("file:///tmp/test.graphml", LocalFileSystem), - ("/tmp/test.graphml", LocalFileSystem), - ("gcs://bucket/file.graphml", GCSFileSystem), - ("https://example.com/file.graphml", HTTPFileSystem), - ], - ) - def test_protocol_usage(self, filepath, instance_type): - data_set = GraphMLDataSet(filepath=filepath) - assert isinstance(data_set._fs, instance_type) - - path = filepath.split(PROTOCOL_DELIMITER, 1)[-1] - - assert str(data_set._filepath) == path - assert isinstance(data_set._filepath, PurePosixPath) - - def test_catalog_release(self, mocker): - fs_mock = mocker.patch("fsspec.filesystem").return_value - filepath = "test.graphml" - data_set = GraphMLDataSet(filepath=filepath) - data_set.release() - fs_mock.invalidate_cache.assert_called_once_with(filepath) - - -class TestGraphMLDataSetVersioned: - def test_save_and_load(self, versioned_graphml_data_set, dummy_graph_data): - """Test that saved and reloaded data matches the original one for - the versioned data set.""" - 
versioned_graphml_data_set.save(dummy_graph_data) - reloaded = versioned_graphml_data_set.load() - assert dummy_graph_data.nodes(data=True) == reloaded.nodes(data=True) - assert versioned_graphml_data_set._fs_open_args_load == {"mode": "rb"} - assert versioned_graphml_data_set._fs_open_args_save == {"mode": "wb"} - - def test_no_versions(self, versioned_graphml_data_set): - """Check the error if no versions are available for load.""" - pattern = r"Did not find any versions for GraphMLDataSet\(.+\)" - with pytest.raises(DatasetError, match=pattern): - versioned_graphml_data_set.load() - - def test_exists(self, versioned_graphml_data_set, dummy_graph_data): - """Test `exists` method invocation for versioned data set.""" - assert not versioned_graphml_data_set.exists() - versioned_graphml_data_set.save(dummy_graph_data) - assert versioned_graphml_data_set.exists() - - def test_prevent_override(self, versioned_graphml_data_set, dummy_graph_data): - """Check the error when attempt to override the same data set - version.""" - versioned_graphml_data_set.save(dummy_graph_data) - pattern = ( - r"Save path \'.+\' for GraphMLDataSet\(.+\) must not " - r"exist if versioning is enabled" - ) - with pytest.raises(DatasetError, match=pattern): - versioned_graphml_data_set.save(dummy_graph_data) - - @pytest.mark.parametrize( - "load_version", ["2019-01-01T23.59.59.999Z"], indirect=True - ) - @pytest.mark.parametrize( - "save_version", ["2019-01-02T00.00.00.000Z"], indirect=True - ) - def test_save_version_warning( - self, versioned_graphml_data_set, load_version, save_version, dummy_graph_data - ): - """Check the warning when saving to the path that differs from - the subsequent load path.""" - pattern = ( - rf"Save version '{save_version}' did not match " - rf"load version '{load_version}' for GraphMLDataSet\(.+\)" - ) - with pytest.warns(UserWarning, match=pattern): - versioned_graphml_data_set.save(dummy_graph_data) - - def test_version_str_repr(self, load_version, save_version): - """Test that version is in string representation of the class instance - when applicable.""" - filepath = "test.graphml" - ds = GraphMLDataSet(filepath=filepath) - ds_versioned = GraphMLDataSet( - filepath=filepath, version=Version(load_version, save_version) - ) - assert filepath in str(ds) - assert "version" not in str(ds) - - assert filepath in str(ds_versioned) - ver_str = f"version=Version(load={load_version}, save='{save_version}')" - assert ver_str in str(ds_versioned) - assert "GraphMLDataSet" in str(ds_versioned) - assert "GraphMLDataSet" in str(ds) - assert "protocol" in str(ds_versioned) - assert "protocol" in str(ds) - - def test_versioning_existing_dataset( - self, graphml_data_set, versioned_graphml_data_set, dummy_graph_data - ): - """Check the error when attempting to save a versioned dataset on top of an - already existing (non-versioned) dataset.""" - graphml_data_set.save(dummy_graph_data) - assert graphml_data_set.exists() - assert graphml_data_set._filepath == versioned_graphml_data_set._filepath - pattern = ( - f"(?=.*file with the same name already exists in the directory)" - f"(?=.*{versioned_graphml_data_set._filepath.parent.as_posix()})" - ) - with pytest.raises(DatasetError, match=pattern): - versioned_graphml_data_set.save(dummy_graph_data) - - # Remove non-versioned dataset and try again - Path(graphml_data_set._filepath.as_posix()).unlink() - versioned_graphml_data_set.save(dummy_graph_data) - assert versioned_graphml_data_set.exists() diff --git 
a/tests/extras/datasets/networkx/test_json_dataset.py b/tests/extras/datasets/networkx/test_json_dataset.py deleted file mode 100644 index 55c7ebd213..0000000000 --- a/tests/extras/datasets/networkx/test_json_dataset.py +++ /dev/null @@ -1,226 +0,0 @@ -from pathlib import Path, PurePosixPath - -import networkx -import pytest -from fsspec.implementations.http import HTTPFileSystem -from fsspec.implementations.local import LocalFileSystem -from gcsfs import GCSFileSystem -from s3fs.core import S3FileSystem - -from kedro.extras.datasets.networkx import JSONDataSet -from kedro.io import DatasetError, Version -from kedro.io.core import PROTOCOL_DELIMITER - -ATTRS = { - "source": "from", - "target": "to", - "name": "fake_id", - "key": "fake_key", - "link": "fake_link", -} - - -@pytest.fixture -def filepath_json(tmp_path): - return (tmp_path / "some_dir" / "test.json").as_posix() - - -@pytest.fixture -def json_data_set(filepath_json, fs_args): - return JSONDataSet(filepath=filepath_json, fs_args=fs_args) - - -@pytest.fixture -def versioned_json_data_set(filepath_json, load_version, save_version): - return JSONDataSet( - filepath=filepath_json, version=Version(load_version, save_version) - ) - - -@pytest.fixture -def json_data_set_args(filepath_json): - return JSONDataSet( - filepath=filepath_json, load_args={"attrs": ATTRS}, save_args={"attrs": ATTRS} - ) - - -@pytest.fixture() -def dummy_graph_data(): - return networkx.complete_graph(3) - - -class TestJSONDataSet: - def test_save_and_load(self, json_data_set, dummy_graph_data): - """Test saving and reloading the data set.""" - json_data_set.save(dummy_graph_data) - reloaded = json_data_set.load() - assert dummy_graph_data.nodes(data=True) == reloaded.nodes(data=True) - assert json_data_set._fs_open_args_load == {} - assert json_data_set._fs_open_args_save == {"mode": "w"} - - def test_load_missing_file(self, json_data_set): - """Check the error when trying to load missing file.""" - pattern = r"Failed while loading data from data set JSONDataSet\(.*\)" - with pytest.raises(DatasetError, match=pattern): - assert json_data_set.load() - - def test_load_args_save_args(self, mocker, json_data_set_args, dummy_graph_data): - """Test saving and reloading with save and load arguments.""" - patched_save = mocker.patch( - "networkx.node_link_data", wraps=networkx.node_link_data - ) - json_data_set_args.save(dummy_graph_data) - patched_save.assert_called_once_with(dummy_graph_data, attrs=ATTRS) - - patched_load = mocker.patch( - "networkx.node_link_graph", wraps=networkx.node_link_graph - ) - # load args need to be the same attrs as the ones used for saving - # in order to successfully retrieve data - reloaded = json_data_set_args.load() - - patched_load.assert_called_once_with( - { - "directed": False, - "multigraph": False, - "graph": {}, - "nodes": [{"fake_id": 0}, {"fake_id": 1}, {"fake_id": 2}], - "fake_link": [ - {"from": 0, "to": 1}, - {"from": 0, "to": 2}, - {"from": 1, "to": 2}, - ], - }, - attrs=ATTRS, - ) - assert dummy_graph_data.nodes(data=True) == reloaded.nodes(data=True) - - @pytest.mark.parametrize( - "fs_args", - [{"open_args_load": {"mode": "rb", "compression": "gzip"}}], - indirect=True, - ) - def test_open_extra_args(self, json_data_set, fs_args): - assert json_data_set._fs_open_args_load == fs_args["open_args_load"] - assert json_data_set._fs_open_args_save == {"mode": "w"} # default unchanged - - def test_exists(self, json_data_set, dummy_graph_data): - """Test `exists` method invocation.""" - assert not json_data_set.exists() - 
json_data_set.save(dummy_graph_data) - assert json_data_set.exists() - - @pytest.mark.parametrize( - "filepath,instance_type", - [ - ("s3://bucket/file.json", S3FileSystem), - ("file:///tmp/test.json", LocalFileSystem), - ("/tmp/test.json", LocalFileSystem), - ("gcs://bucket/file.json", GCSFileSystem), - ("https://example.com/file.json", HTTPFileSystem), - ], - ) - def test_protocol_usage(self, filepath, instance_type): - data_set = JSONDataSet(filepath=filepath) - assert isinstance(data_set._fs, instance_type) - - path = filepath.split(PROTOCOL_DELIMITER, 1)[-1] - - assert str(data_set._filepath) == path - assert isinstance(data_set._filepath, PurePosixPath) - - def test_catalog_release(self, mocker): - fs_mock = mocker.patch("fsspec.filesystem").return_value - filepath = "test.json" - data_set = JSONDataSet(filepath=filepath) - data_set.release() - fs_mock.invalidate_cache.assert_called_once_with(filepath) - - -class TestJSONDataSetVersioned: - def test_save_and_load(self, versioned_json_data_set, dummy_graph_data): - """Test that saved and reloaded data matches the original one for - the versioned data set.""" - versioned_json_data_set.save(dummy_graph_data) - reloaded = versioned_json_data_set.load() - assert dummy_graph_data.nodes(data=True) == reloaded.nodes(data=True) - - def test_no_versions(self, versioned_json_data_set): - """Check the error if no versions are available for load.""" - pattern = r"Did not find any versions for JSONDataSet\(.+\)" - with pytest.raises(DatasetError, match=pattern): - versioned_json_data_set.load() - - def test_exists(self, versioned_json_data_set, dummy_graph_data): - """Test `exists` method invocation for versioned data set.""" - assert not versioned_json_data_set.exists() - versioned_json_data_set.save(dummy_graph_data) - assert versioned_json_data_set.exists() - - def test_prevent_override(self, versioned_json_data_set, dummy_graph_data): - """Check the error when attempt to override the same data set - version.""" - versioned_json_data_set.save(dummy_graph_data) - pattern = ( - r"Save path \'.+\' for JSONDataSet\(.+\) must not " - r"exist if versioning is enabled" - ) - with pytest.raises(DatasetError, match=pattern): - versioned_json_data_set.save(dummy_graph_data) - - @pytest.mark.parametrize( - "load_version", ["2019-01-01T23.59.59.999Z"], indirect=True - ) - @pytest.mark.parametrize( - "save_version", ["2019-01-02T00.00.00.000Z"], indirect=True - ) - def test_save_version_warning( - self, versioned_json_data_set, load_version, save_version, dummy_graph_data - ): - """Check the warning when saving to the path that differs from - the subsequent load path.""" - pattern = ( - rf"Save version '{save_version}' did not match load version " - rf"'{load_version}' for JSONDataSet\(.+\)" - ) - with pytest.warns(UserWarning, match=pattern): - versioned_json_data_set.save(dummy_graph_data) - - def test_version_str_repr(self, load_version, save_version): - """Test that version is in string representation of the class instance - when applicable.""" - filepath = "test.json" - ds = JSONDataSet(filepath=filepath) - ds_versioned = JSONDataSet( - filepath=filepath, version=Version(load_version, save_version) - ) - assert filepath in str(ds) - assert "version" not in str(ds) - - assert filepath in str(ds_versioned) - ver_str = f"version=Version(load={load_version}, save='{save_version}')" - assert ver_str in str(ds_versioned) - assert "JSONDataSet" in str(ds_versioned) - assert "JSONDataSet" in str(ds) - assert "protocol" in str(ds_versioned) - assert "protocol" 
in str(ds) - - def test_versioning_existing_dataset( - self, json_data_set, versioned_json_data_set, dummy_graph_data - ): - """Check the error when attempting to save a versioned dataset on top of an - already existing (non-versioned) dataset.""" - json_data_set.save(dummy_graph_data) - assert json_data_set.exists() - assert json_data_set._filepath == versioned_json_data_set._filepath - pattern = ( - f"(?=.*file with the same name already exists in the directory)" - f"(?=.*{versioned_json_data_set._filepath.parent.as_posix()})" - ) - with pytest.raises(DatasetError, match=pattern): - versioned_json_data_set.save(dummy_graph_data) - - # Remove non-versioned dataset and try again - Path(json_data_set._filepath.as_posix()).unlink() - versioned_json_data_set.save(dummy_graph_data) - assert versioned_json_data_set.exists() diff --git a/tests/extras/datasets/pandas/__init__.py b/tests/extras/datasets/pandas/__init__.py deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/tests/extras/datasets/pandas/test_csv_dataset.py b/tests/extras/datasets/pandas/test_csv_dataset.py deleted file mode 100644 index a2a15f5938..0000000000 --- a/tests/extras/datasets/pandas/test_csv_dataset.py +++ /dev/null @@ -1,300 +0,0 @@ -from pathlib import Path, PurePosixPath -from time import sleep - -import pandas as pd -import pytest -from adlfs import AzureBlobFileSystem -from fsspec.implementations.http import HTTPFileSystem -from fsspec.implementations.local import LocalFileSystem -from gcsfs import GCSFileSystem -from pandas.testing import assert_frame_equal -from s3fs.core import S3FileSystem - -from kedro.extras.datasets.pandas import CSVDataSet -from kedro.io import DatasetError -from kedro.io.core import PROTOCOL_DELIMITER, Version, generate_timestamp - - -@pytest.fixture -def filepath_csv(tmp_path): - return (tmp_path / "test.csv").as_posix() - - -@pytest.fixture -def csv_data_set(filepath_csv, load_args, save_args, fs_args): - return CSVDataSet( - filepath=filepath_csv, load_args=load_args, save_args=save_args, fs_args=fs_args - ) - - -@pytest.fixture -def versioned_csv_data_set(filepath_csv, load_version, save_version): - return CSVDataSet( - filepath=filepath_csv, version=Version(load_version, save_version) - ) - - -@pytest.fixture -def dummy_dataframe(): - return pd.DataFrame({"col1": [1, 2], "col2": [4, 5], "col3": [5, 6]}) - - -class TestCSVDataSet: - def test_save_and_load(self, csv_data_set, dummy_dataframe): - """Test saving and reloading the data set.""" - csv_data_set.save(dummy_dataframe) - reloaded = csv_data_set.load() - assert_frame_equal(dummy_dataframe, reloaded) - - def test_exists(self, csv_data_set, dummy_dataframe): - """Test `exists` method invocation for both existing and - nonexistent data set.""" - assert not csv_data_set.exists() - csv_data_set.save(dummy_dataframe) - assert csv_data_set.exists() - - @pytest.mark.parametrize( - "load_args", [{"k1": "v1", "index": "value"}], indirect=True - ) - def test_load_extra_params(self, csv_data_set, load_args): - """Test overriding the default load arguments.""" - for key, value in load_args.items(): - assert csv_data_set._load_args[key] == value - - @pytest.mark.parametrize( - "save_args", [{"k1": "v1", "index": "value"}], indirect=True - ) - def test_save_extra_params(self, csv_data_set, save_args): - """Test overriding the default save arguments.""" - for key, value in save_args.items(): - assert csv_data_set._save_args[key] == value - - @pytest.mark.parametrize( - "load_args,save_args", - [ - ({"storage_options": {"a": "b"}}, 
{}), - ({}, {"storage_options": {"a": "b"}}), - ({"storage_options": {"a": "b"}}, {"storage_options": {"x": "y"}}), - ], - ) - def test_storage_options_dropped(self, load_args, save_args, caplog, tmp_path): - filepath = str(tmp_path / "test.csv") - - ds = CSVDataSet(filepath=filepath, load_args=load_args, save_args=save_args) - - records = [r for r in caplog.records if r.levelname == "WARNING"] - expected_log_message = ( - f"Dropping 'storage_options' for {filepath}, " - f"please specify them under 'fs_args' or 'credentials'." - ) - assert records[0].getMessage() == expected_log_message - assert "storage_options" not in ds._save_args - assert "storage_options" not in ds._load_args - - def test_load_missing_file(self, csv_data_set): - """Check the error when trying to load missing file.""" - pattern = r"Failed while loading data from data set CSVDataSet\(.*\)" - with pytest.raises(DatasetError, match=pattern): - csv_data_set.load() - - @pytest.mark.parametrize( - "filepath,instance_type,credentials", - [ - ("s3://bucket/file.csv", S3FileSystem, {}), - ("file:///tmp/test.csv", LocalFileSystem, {}), - ("/tmp/test.csv", LocalFileSystem, {}), - ("gcs://bucket/file.csv", GCSFileSystem, {}), - ("https://example.com/file.csv", HTTPFileSystem, {}), - ( - "abfs://bucket/file.csv", - AzureBlobFileSystem, - {"account_name": "test", "account_key": "test"}, - ), - ], - ) - def test_protocol_usage(self, filepath, instance_type, credentials): - data_set = CSVDataSet(filepath=filepath, credentials=credentials) - assert isinstance(data_set._fs, instance_type) - - path = filepath.split(PROTOCOL_DELIMITER, 1)[-1] - - assert str(data_set._filepath) == path - assert isinstance(data_set._filepath, PurePosixPath) - - def test_catalog_release(self, mocker): - fs_mock = mocker.patch("fsspec.filesystem").return_value - filepath = "test.csv" - data_set = CSVDataSet(filepath=filepath) - assert data_set._version_cache.currsize == 0 # no cache if unversioned - data_set.release() - fs_mock.invalidate_cache.assert_called_once_with(filepath) - assert data_set._version_cache.currsize == 0 - - -class TestCSVDataSetVersioned: - def test_version_str_repr(self, load_version, save_version): - """Test that version is in string representation of the class instance - when applicable.""" - filepath = "test.csv" - ds = CSVDataSet(filepath=filepath) - ds_versioned = CSVDataSet( - filepath=filepath, version=Version(load_version, save_version) - ) - assert filepath in str(ds) - assert "version" not in str(ds) - - assert filepath in str(ds_versioned) - ver_str = f"version=Version(load={load_version}, save='{save_version}')" - assert ver_str in str(ds_versioned) - assert "CSVDataSet" in str(ds_versioned) - assert "CSVDataSet" in str(ds) - assert "protocol" in str(ds_versioned) - assert "protocol" in str(ds) - # Default save_args - assert "save_args={'index': False}" in str(ds) - assert "save_args={'index': False}" in str(ds_versioned) - - def test_save_and_load(self, versioned_csv_data_set, dummy_dataframe): - """Test that saved and reloaded data matches the original one for - the versioned data set.""" - versioned_csv_data_set.save(dummy_dataframe) - reloaded_df = versioned_csv_data_set.load() - assert_frame_equal(dummy_dataframe, reloaded_df) - - def test_multiple_loads( - self, versioned_csv_data_set, dummy_dataframe, filepath_csv - ): - """Test that if a new version is created mid-run, by an - external system, it won't be loaded in the current run.""" - versioned_csv_data_set.save(dummy_dataframe) - versioned_csv_data_set.load() - v1 
= versioned_csv_data_set.resolve_load_version() - - sleep(0.5) - # force-drop a newer version into the same location - v_new = generate_timestamp() - CSVDataSet(filepath=filepath_csv, version=Version(v_new, v_new)).save( - dummy_dataframe - ) - - versioned_csv_data_set.load() - v2 = versioned_csv_data_set.resolve_load_version() - - assert v2 == v1 # v2 should not be v_new! - ds_new = CSVDataSet(filepath=filepath_csv, version=Version(None, None)) - assert ( - ds_new.resolve_load_version() == v_new - ) # new version is discoverable by a new instance - - def test_multiple_saves(self, dummy_dataframe, filepath_csv): - """Test multiple cycles of save followed by load for the same dataset""" - ds_versioned = CSVDataSet(filepath=filepath_csv, version=Version(None, None)) - - # first save - ds_versioned.save(dummy_dataframe) - first_save_version = ds_versioned.resolve_save_version() - first_load_version = ds_versioned.resolve_load_version() - assert first_load_version == first_save_version - - # second save - sleep(0.5) - ds_versioned.save(dummy_dataframe) - second_save_version = ds_versioned.resolve_save_version() - second_load_version = ds_versioned.resolve_load_version() - assert second_load_version == second_save_version - assert second_load_version > first_load_version - - # another dataset - ds_new = CSVDataSet(filepath=filepath_csv, version=Version(None, None)) - assert ds_new.resolve_load_version() == second_load_version - - def test_release_instance_cache(self, dummy_dataframe, filepath_csv): - """Test that cache invalidation does not affect other instances""" - ds_a = CSVDataSet(filepath=filepath_csv, version=Version(None, None)) - assert ds_a._version_cache.currsize == 0 - ds_a.save(dummy_dataframe) # create a version - assert ds_a._version_cache.currsize == 2 - - ds_b = CSVDataSet(filepath=filepath_csv, version=Version(None, None)) - assert ds_b._version_cache.currsize == 0 - ds_b.resolve_save_version() - assert ds_b._version_cache.currsize == 1 - ds_b.resolve_load_version() - assert ds_b._version_cache.currsize == 2 - - ds_a.release() - - # dataset A cache is cleared - assert ds_a._version_cache.currsize == 0 - - # dataset B cache is unaffected - assert ds_b._version_cache.currsize == 2 - - def test_no_versions(self, versioned_csv_data_set): - """Check the error if no versions are available for load.""" - pattern = r"Did not find any versions for CSVDataSet\(.+\)" - with pytest.raises(DatasetError, match=pattern): - versioned_csv_data_set.load() - - def test_exists(self, versioned_csv_data_set, dummy_dataframe): - """Test `exists` method invocation for versioned data set.""" - assert not versioned_csv_data_set.exists() - versioned_csv_data_set.save(dummy_dataframe) - assert versioned_csv_data_set.exists() - - def test_prevent_overwrite(self, versioned_csv_data_set, dummy_dataframe): - """Check the error when attempting to override the data set if the - corresponding CSV file for a given save version already exists.""" - versioned_csv_data_set.save(dummy_dataframe) - pattern = ( - r"Save path \'.+\' for CSVDataSet\(.+\) must " - r"not exist if versioning is enabled\." 
- ) - with pytest.raises(DatasetError, match=pattern): - versioned_csv_data_set.save(dummy_dataframe) - - @pytest.mark.parametrize( - "load_version", ["2019-01-01T23.59.59.999Z"], indirect=True - ) - @pytest.mark.parametrize( - "save_version", ["2019-01-02T00.00.00.000Z"], indirect=True - ) - def test_save_version_warning( - self, versioned_csv_data_set, load_version, save_version, dummy_dataframe - ): - """Check the warning when saving to the path that differs from - the subsequent load path.""" - pattern = ( - rf"Save version '{save_version}' did not match load version " - rf"'{load_version}' for CSVDataSet\(.+\)" - ) - with pytest.warns(UserWarning, match=pattern): - versioned_csv_data_set.save(dummy_dataframe) - - def test_http_filesystem_no_versioning(self): - pattern = "Versioning is not supported for HTTP protocols." - - with pytest.raises(DatasetError, match=pattern): - CSVDataSet( - filepath="https://example.com/file.csv", version=Version(None, None) - ) - - def test_versioning_existing_dataset( - self, csv_data_set, versioned_csv_data_set, dummy_dataframe - ): - """Check the error when attempting to save a versioned dataset on top of an - already existing (non-versioned) dataset.""" - csv_data_set.save(dummy_dataframe) - assert csv_data_set.exists() - assert csv_data_set._filepath == versioned_csv_data_set._filepath - pattern = ( - f"(?=.*file with the same name already exists in the directory)" - f"(?=.*{versioned_csv_data_set._filepath.parent.as_posix()})" - ) - with pytest.raises(DatasetError, match=pattern): - versioned_csv_data_set.save(dummy_dataframe) - - # Remove non-versioned dataset and try again - Path(csv_data_set._filepath.as_posix()).unlink() - versioned_csv_data_set.save(dummy_dataframe) - assert versioned_csv_data_set.exists() diff --git a/tests/extras/datasets/pandas/test_excel_dataset.py b/tests/extras/datasets/pandas/test_excel_dataset.py deleted file mode 100644 index d558d3b22f..0000000000 --- a/tests/extras/datasets/pandas/test_excel_dataset.py +++ /dev/null @@ -1,281 +0,0 @@ -from pathlib import Path, PurePosixPath - -import pandas as pd -import pytest -from fsspec.implementations.http import HTTPFileSystem -from fsspec.implementations.local import LocalFileSystem -from gcsfs import GCSFileSystem -from pandas.testing import assert_frame_equal -from s3fs.core import S3FileSystem - -from kedro.extras.datasets.pandas import ExcelDataSet -from kedro.io import DatasetError -from kedro.io.core import PROTOCOL_DELIMITER, Version - - -@pytest.fixture -def filepath_excel(tmp_path): - return (tmp_path / "test.xlsx").as_posix() - - -@pytest.fixture -def excel_data_set(filepath_excel, load_args, save_args, fs_args): - return ExcelDataSet( - filepath=filepath_excel, - load_args=load_args, - save_args=save_args, - fs_args=fs_args, - ) - - -@pytest.fixture -def excel_multisheet_data_set(filepath_excel, save_args, fs_args): - load_args = {"sheet_name": None} - return ExcelDataSet( - filepath=filepath_excel, - load_args=load_args, - save_args=save_args, - fs_args=fs_args, - ) - - -@pytest.fixture -def versioned_excel_data_set(filepath_excel, load_version, save_version): - return ExcelDataSet( - filepath=filepath_excel, version=Version(load_version, save_version) - ) - - -@pytest.fixture -def dummy_dataframe(): - return pd.DataFrame({"col1": [1, 2], "col2": [4, 5], "col3": [5, 6]}) - - -@pytest.fixture -def another_dummy_dataframe(): - return pd.DataFrame({"x": [10, 20], "y": ["hello", "world"]}) - - -class TestExcelDataSet: - def test_save_and_load(self, excel_data_set, 
dummy_dataframe): - """Test saving and reloading the data set.""" - excel_data_set.save(dummy_dataframe) - reloaded = excel_data_set.load() - assert_frame_equal(dummy_dataframe, reloaded) - - def test_save_and_load_multiple_sheets( - self, excel_multisheet_data_set, dummy_dataframe, another_dummy_dataframe - ): - """Test saving and reloading the data set with multiple sheets.""" - dummy_multisheet = { - "sheet 1": dummy_dataframe, - "sheet 2": another_dummy_dataframe, - } - excel_multisheet_data_set.save(dummy_multisheet) - reloaded = excel_multisheet_data_set.load() - assert_frame_equal(dummy_multisheet["sheet 1"], reloaded["sheet 1"]) - assert_frame_equal(dummy_multisheet["sheet 2"], reloaded["sheet 2"]) - - def test_exists(self, excel_data_set, dummy_dataframe): - """Test `exists` method invocation for both existing and - nonexistent data set.""" - assert not excel_data_set.exists() - excel_data_set.save(dummy_dataframe) - assert excel_data_set.exists() - - @pytest.mark.parametrize( - "load_args", [{"k1": "v1", "index": "value"}], indirect=True - ) - def test_load_extra_params(self, excel_data_set, load_args): - """Test overriding the default load arguments.""" - for key, value in load_args.items(): - assert excel_data_set._load_args[key] == value - - @pytest.mark.parametrize( - "save_args", [{"k1": "v1", "index": "value"}], indirect=True - ) - def test_save_extra_params(self, excel_data_set, save_args): - """Test overriding the default save arguments.""" - for key, value in save_args.items(): - assert excel_data_set._save_args[key] == value - - @pytest.mark.parametrize( - "load_args,save_args", - [ - ({"storage_options": {"a": "b"}}, {}), - ({}, {"storage_options": {"a": "b"}}), - ({"storage_options": {"a": "b"}}, {"storage_options": {"x": "y"}}), - ], - ) - def test_storage_options_dropped(self, load_args, save_args, caplog, tmp_path): - filepath = str(tmp_path / "test.csv") - - ds = ExcelDataSet(filepath=filepath, load_args=load_args, save_args=save_args) - - records = [r for r in caplog.records if r.levelname == "WARNING"] - expected_log_message = ( - f"Dropping 'storage_options' for {filepath}, " - f"please specify them under 'fs_args' or 'credentials'." 
- ) - assert records[0].getMessage() == expected_log_message - assert "storage_options" not in ds._save_args - assert "storage_options" not in ds._load_args - - def test_load_missing_file(self, excel_data_set): - """Check the error when trying to load missing file.""" - pattern = r"Failed while loading data from data set ExcelDataSet\(.*\)" - with pytest.raises(DatasetError, match=pattern): - excel_data_set.load() - - @pytest.mark.parametrize( - "filepath,instance_type,load_path", - [ - ("s3://bucket/file.xlsx", S3FileSystem, "s3://bucket/file.xlsx"), - ("file:///tmp/test.xlsx", LocalFileSystem, "/tmp/test.xlsx"), - ("/tmp/test.xlsx", LocalFileSystem, "/tmp/test.xlsx"), - ("gcs://bucket/file.xlsx", GCSFileSystem, "gcs://bucket/file.xlsx"), - ( - "https://example.com/file.xlsx", - HTTPFileSystem, - "https://example.com/file.xlsx", - ), - ], - ) - def test_protocol_usage(self, filepath, instance_type, load_path, mocker): - data_set = ExcelDataSet(filepath=filepath) - assert isinstance(data_set._fs, instance_type) - - path = filepath.split(PROTOCOL_DELIMITER, 1)[-1] - - assert str(data_set._filepath) == path - assert isinstance(data_set._filepath, PurePosixPath) - - mock_pandas_call = mocker.patch("pandas.read_excel") - data_set.load() - assert mock_pandas_call.call_count == 1 - assert mock_pandas_call.call_args_list[0][0][0] == load_path - - def test_catalog_release(self, mocker): - fs_mock = mocker.patch("fsspec.filesystem").return_value - filepath = "test.xlsx" - data_set = ExcelDataSet(filepath=filepath) - data_set.release() - fs_mock.invalidate_cache.assert_called_once_with(filepath) - - -class TestExcelDataSetVersioned: - def test_version_str_repr(self, load_version, save_version): - """Test that version is in string representation of the class instance - when applicable.""" - filepath = "test.xlsx" - ds = ExcelDataSet(filepath=filepath) - ds_versioned = ExcelDataSet( - filepath=filepath, version=Version(load_version, save_version) - ) - assert filepath in str(ds) - assert "version" not in str(ds) - - assert filepath in str(ds_versioned) - ver_str = f"version=Version(load={load_version}, save='{save_version}')" - assert ver_str in str(ds_versioned) - assert "ExcelDataSet" in str(ds_versioned) - assert "ExcelDataSet" in str(ds) - assert "protocol" in str(ds_versioned) - assert "protocol" in str(ds) - assert "writer_args" in str(ds_versioned) - assert "writer_args" in str(ds) - # Default save_args and load_args - assert "save_args={'index': False}" in str(ds) - assert "save_args={'index': False}" in str(ds_versioned) - assert "load_args={'engine': openpyxl}" in str(ds_versioned) - assert "load_args={'engine': openpyxl}" in str(ds) - - def test_save_and_load(self, versioned_excel_data_set, dummy_dataframe): - """Test that saved and reloaded data matches the original one for - the versioned data set.""" - versioned_excel_data_set.save(dummy_dataframe) - reloaded_df = versioned_excel_data_set.load() - assert_frame_equal(dummy_dataframe, reloaded_df) - - def test_no_versions(self, versioned_excel_data_set): - """Check the error if no versions are available for load.""" - pattern = r"Did not find any versions for ExcelDataSet\(.+\)" - with pytest.raises(DatasetError, match=pattern): - versioned_excel_data_set.load() - - def test_versioning_not_supported_in_append_mode( - self, tmp_path, load_version, save_version - ): - filepath = str(tmp_path / "test.xlsx") - save_args = {"writer": {"mode": "a"}} - - pattern = "'ExcelDataSet' doesn't support versioning in append mode." 
- with pytest.raises(DatasetError, match=pattern): - ExcelDataSet( - filepath=filepath, - version=Version(load_version, save_version), - save_args=save_args, - ) - - def test_exists(self, versioned_excel_data_set, dummy_dataframe): - """Test `exists` method invocation for versioned data set.""" - assert not versioned_excel_data_set.exists() - versioned_excel_data_set.save(dummy_dataframe) - assert versioned_excel_data_set.exists() - - def test_prevent_overwrite(self, versioned_excel_data_set, dummy_dataframe): - """Check the error when attempting to override the data set if the - corresponding Excel file for a given save version already exists.""" - versioned_excel_data_set.save(dummy_dataframe) - pattern = ( - r"Save path \'.+\' for ExcelDataSet\(.+\) must " - r"not exist if versioning is enabled\." - ) - with pytest.raises(DatasetError, match=pattern): - versioned_excel_data_set.save(dummy_dataframe) - - @pytest.mark.parametrize( - "load_version", ["2019-01-01T23.59.59.999Z"], indirect=True - ) - @pytest.mark.parametrize( - "save_version", ["2019-01-02T00.00.00.000Z"], indirect=True - ) - def test_save_version_warning( - self, versioned_excel_data_set, load_version, save_version, dummy_dataframe - ): - """Check the warning when saving to the path that differs from - the subsequent load path.""" - pattern = ( - rf"Save version '{save_version}' did not match load version " - rf"'{load_version}' for ExcelDataSet\(.+\)" - ) - with pytest.warns(UserWarning, match=pattern): - versioned_excel_data_set.save(dummy_dataframe) - - def test_http_filesystem_no_versioning(self): - pattern = "Versioning is not supported for HTTP protocols." - - with pytest.raises(DatasetError, match=pattern): - ExcelDataSet( - filepath="https://example.com/file.xlsx", version=Version(None, None) - ) - - def test_versioning_existing_dataset( - self, excel_data_set, versioned_excel_data_set, dummy_dataframe - ): - """Check the error when attempting to save a versioned dataset on top of an - already existing (non-versioned) dataset.""" - excel_data_set.save(dummy_dataframe) - assert excel_data_set.exists() - assert excel_data_set._filepath == versioned_excel_data_set._filepath - pattern = ( - f"(?=.*file with the same name already exists in the directory)" - f"(?=.*{versioned_excel_data_set._filepath.parent.as_posix()})" - ) - with pytest.raises(DatasetError, match=pattern): - versioned_excel_data_set.save(dummy_dataframe) - - # Remove non-versioned dataset and try again - Path(excel_data_set._filepath.as_posix()).unlink() - versioned_excel_data_set.save(dummy_dataframe) - assert versioned_excel_data_set.exists() diff --git a/tests/extras/datasets/pandas/test_feather_dataset.py b/tests/extras/datasets/pandas/test_feather_dataset.py deleted file mode 100644 index 8637bd2bcf..0000000000 --- a/tests/extras/datasets/pandas/test_feather_dataset.py +++ /dev/null @@ -1,220 +0,0 @@ -from pathlib import Path, PurePosixPath - -import pandas as pd -import pytest -from fsspec.implementations.http import HTTPFileSystem -from fsspec.implementations.local import LocalFileSystem -from gcsfs import GCSFileSystem -from pandas.testing import assert_frame_equal -from s3fs.core import S3FileSystem - -from kedro.extras.datasets.pandas import FeatherDataSet -from kedro.io import DatasetError -from kedro.io.core import PROTOCOL_DELIMITER, Version - - -@pytest.fixture -def filepath_feather(tmp_path): - return (tmp_path / "test.feather").as_posix() - - -@pytest.fixture -def feather_data_set(filepath_feather, load_args, fs_args): - return 
FeatherDataSet( - filepath=filepath_feather, load_args=load_args, fs_args=fs_args - ) - - -@pytest.fixture -def versioned_feather_data_set(filepath_feather, load_version, save_version): - return FeatherDataSet( - filepath=filepath_feather, version=Version(load_version, save_version) - ) - - -@pytest.fixture -def dummy_dataframe(): - return pd.DataFrame({"col1": [1, 2], "col2": [4, 5], "col3": [5, 6]}) - - -class TestFeatherDataSet: - def test_save_and_load(self, feather_data_set, dummy_dataframe): - """Test saving and reloading the data set.""" - feather_data_set.save(dummy_dataframe) - reloaded = feather_data_set.load() - assert_frame_equal(dummy_dataframe, reloaded) - - def test_exists(self, feather_data_set, dummy_dataframe): - """Test `exists` method invocation for both existing and - nonexistent data set.""" - assert not feather_data_set.exists() - feather_data_set.save(dummy_dataframe) - assert feather_data_set.exists() - - @pytest.mark.parametrize( - "load_args", [{"k1": "v1", "index": "value"}], indirect=True - ) - def test_load_extra_params(self, feather_data_set, load_args): - """Test overriding the default load arguments.""" - for key, value in load_args.items(): - assert feather_data_set._load_args[key] == value - - @pytest.mark.parametrize( - "load_args,save_args", - [ - ({"storage_options": {"a": "b"}}, {}), - ({}, {"storage_options": {"a": "b"}}), - ({"storage_options": {"a": "b"}}, {"storage_options": {"x": "y"}}), - ], - ) - def test_storage_options_dropped(self, load_args, save_args, caplog, tmp_path): - filepath = str(tmp_path / "test.csv") - - ds = FeatherDataSet(filepath=filepath, load_args=load_args, save_args=save_args) - - records = [r for r in caplog.records if r.levelname == "WARNING"] - expected_log_message = ( - f"Dropping 'storage_options' for {filepath}, " - f"please specify them under 'fs_args' or 'credentials'." 
- ) - assert records[0].getMessage() == expected_log_message - assert "storage_options" not in ds._save_args - assert "storage_options" not in ds._load_args - - def test_load_missing_file(self, feather_data_set): - """Check the error when trying to load missing file.""" - pattern = r"Failed while loading data from data set FeatherDataSet\(.*\)" - with pytest.raises(DatasetError, match=pattern): - feather_data_set.load() - - @pytest.mark.parametrize( - "filepath,instance_type,load_path", - [ - ("s3://bucket/file.feather", S3FileSystem, "s3://bucket/file.feather"), - ("file:///tmp/test.feather", LocalFileSystem, "/tmp/test.feather"), - ("/tmp/test.feather", LocalFileSystem, "/tmp/test.feather"), - ("gcs://bucket/file.feather", GCSFileSystem, "gcs://bucket/file.feather"), - ( - "https://example.com/file.feather", - HTTPFileSystem, - "https://example.com/file.feather", - ), - ], - ) - def test_protocol_usage(self, filepath, instance_type, load_path, mocker): - data_set = FeatherDataSet(filepath=filepath) - assert isinstance(data_set._fs, instance_type) - - path = filepath.split(PROTOCOL_DELIMITER, 1)[-1] - - assert str(data_set._filepath) == path - assert isinstance(data_set._filepath, PurePosixPath) - - mock_pandas_call = mocker.patch("pandas.read_feather") - data_set.load() - assert mock_pandas_call.call_count == 1 - assert mock_pandas_call.call_args_list[0][0][0] == load_path - - def test_catalog_release(self, mocker): - fs_mock = mocker.patch("fsspec.filesystem").return_value - filepath = "test.feather" - data_set = FeatherDataSet(filepath=filepath) - data_set.release() - fs_mock.invalidate_cache.assert_called_once_with(filepath) - - -class TestFeatherDataSetVersioned: - def test_version_str_repr(self, load_version, save_version): - """Test that version is in string representation of the class instance - when applicable.""" - filepath = "test.feather" - ds = FeatherDataSet(filepath=filepath) - ds_versioned = FeatherDataSet( - filepath=filepath, version=Version(load_version, save_version) - ) - assert filepath in str(ds) - assert "version" not in str(ds) - - assert filepath in str(ds_versioned) - ver_str = f"version=Version(load={load_version}, save='{save_version}')" - assert ver_str in str(ds_versioned) - assert "FeatherDataSet" in str(ds_versioned) - assert "FeatherDataSet" in str(ds) - assert "protocol" in str(ds_versioned) - assert "protocol" in str(ds) - - def test_save_and_load(self, versioned_feather_data_set, dummy_dataframe): - """Test that saved and reloaded data matches the original one for - the versioned data set.""" - versioned_feather_data_set.save(dummy_dataframe) - reloaded_df = versioned_feather_data_set.load() - assert_frame_equal(dummy_dataframe, reloaded_df) - - def test_no_versions(self, versioned_feather_data_set): - """Check the error if no versions are available for load.""" - pattern = r"Did not find any versions for FeatherDataSet\(.+\)" - with pytest.raises(DatasetError, match=pattern): - versioned_feather_data_set.load() - - def test_exists(self, versioned_feather_data_set, dummy_dataframe): - """Test `exists` method invocation for versioned data set.""" - assert not versioned_feather_data_set.exists() - versioned_feather_data_set.save(dummy_dataframe) - assert versioned_feather_data_set.exists() - - def test_prevent_overwrite(self, versioned_feather_data_set, dummy_dataframe): - """Check the error when attempting to overwrite the data set if the - corresponding feather file for a given save version already exists.""" - 
versioned_feather_data_set.save(dummy_dataframe) - pattern = ( - r"Save path \'.+\' for FeatherDataSet\(.+\) must " - r"not exist if versioning is enabled\." - ) - with pytest.raises(DatasetError, match=pattern): - versioned_feather_data_set.save(dummy_dataframe) - - @pytest.mark.parametrize( - "load_version", ["2019-01-01T23.59.59.999Z"], indirect=True - ) - @pytest.mark.parametrize( - "save_version", ["2019-01-02T00.00.00.000Z"], indirect=True - ) - def test_save_version_warning( - self, versioned_feather_data_set, load_version, save_version, dummy_dataframe - ): - """Check the warning when saving to the path that differs from - the subsequent load path.""" - pattern = ( - rf"Save version '{save_version}' did not match load version " - rf"'{load_version}' for FeatherDataSet\(.+\)" - ) - with pytest.warns(UserWarning, match=pattern): - versioned_feather_data_set.save(dummy_dataframe) - - def test_http_filesystem_no_versioning(self): - pattern = "Versioning is not supported for HTTP protocols." - - with pytest.raises(DatasetError, match=pattern): - FeatherDataSet( - filepath="https://example.com/file.feather", version=Version(None, None) - ) - - def test_versioning_existing_dataset( - self, feather_data_set, versioned_feather_data_set, dummy_dataframe - ): - """Check the error when attempting to save a versioned dataset on top of an - already existing (non-versioned) dataset.""" - feather_data_set.save(dummy_dataframe) - assert feather_data_set.exists() - assert feather_data_set._filepath == versioned_feather_data_set._filepath - pattern = ( - f"(?=.*file with the same name already exists in the directory)" - f"(?=.*{versioned_feather_data_set._filepath.parent.as_posix()})" - ) - with pytest.raises(DatasetError, match=pattern): - versioned_feather_data_set.save(dummy_dataframe) - - # Remove non-versioned dataset and try again - Path(feather_data_set._filepath.as_posix()).unlink() - versioned_feather_data_set.save(dummy_dataframe) - assert versioned_feather_data_set.exists() diff --git a/tests/extras/datasets/pandas/test_gbq_dataset.py b/tests/extras/datasets/pandas/test_gbq_dataset.py deleted file mode 100644 index 475f25c93b..0000000000 --- a/tests/extras/datasets/pandas/test_gbq_dataset.py +++ /dev/null @@ -1,315 +0,0 @@ -from pathlib import PosixPath - -import pandas as pd -import pytest -from google.cloud.exceptions import NotFound -from pandas.testing import assert_frame_equal - -from kedro.extras.datasets.pandas import GBQQueryDataSet, GBQTableDataSet -from kedro.io.core import DatasetError - -DATASET = "dataset" -TABLE_NAME = "table_name" -PROJECT = "project" -SQL_QUERY = "SELECT * FROM table_a" - - -@pytest.fixture -def dummy_dataframe(): - return pd.DataFrame({"col1": [1, 2], "col2": [4, 5], "col3": [5, 6]}) - - -@pytest.fixture -def mock_bigquery_client(mocker): - mocked = mocker.patch("google.cloud.bigquery.Client", autospec=True) - return mocked - - -@pytest.fixture -def gbq_dataset( - load_args, save_args, mock_bigquery_client -): # pylint: disable=unused-argument - return GBQTableDataSet( - dataset=DATASET, - table_name=TABLE_NAME, - project=PROJECT, - credentials=None, - load_args=load_args, - save_args=save_args, - ) - - -@pytest.fixture(params=[{}]) -def gbq_sql_dataset(load_args, mock_bigquery_client): # pylint: disable=unused-argument - return GBQQueryDataSet( - sql=SQL_QUERY, - project=PROJECT, - credentials=None, - load_args=load_args, - ) - - -@pytest.fixture -def sql_file(tmp_path: PosixPath): - file = tmp_path / "test.sql" - file.write_text(SQL_QUERY) - return 
file.as_posix() - - -@pytest.fixture(params=[{}]) -def gbq_sql_file_dataset( - load_args, sql_file, mock_bigquery_client -): # pylint: disable=unused-argument - return GBQQueryDataSet( - filepath=sql_file, - project=PROJECT, - credentials=None, - load_args=load_args, - ) - - -class TestGBQDataSet: - def test_exists(self, mock_bigquery_client): - """Test `exists` method invocation.""" - mock_bigquery_client.return_value.get_table.side_effect = [ - NotFound("NotFound"), - "exists", - ] - - data_set = GBQTableDataSet(DATASET, TABLE_NAME) - assert not data_set.exists() - assert data_set.exists() - - @pytest.mark.parametrize( - "load_args", [{"k1": "v1", "index": "value"}], indirect=True - ) - def test_load_extra_params(self, gbq_dataset, load_args): - """Test overriding the default load arguments.""" - for key, value in load_args.items(): - assert gbq_dataset._load_args[key] == value - - @pytest.mark.parametrize( - "save_args", [{"k1": "v1", "index": "value"}], indirect=True - ) - def test_save_extra_params(self, gbq_dataset, save_args): - """Test overriding the default save arguments.""" - for key, value in save_args.items(): - assert gbq_dataset._save_args[key] == value - - def test_load_missing_file(self, gbq_dataset, mocker): - """Check the error when trying to load missing table.""" - pattern = r"Failed while loading data from data set GBQTableDataSet\(.*\)" - mocked_read_gbq = mocker.patch( - "kedro.extras.datasets.pandas.gbq_dataset.pd.read_gbq" - ) - mocked_read_gbq.side_effect = ValueError - with pytest.raises(DatasetError, match=pattern): - gbq_dataset.load() - - @pytest.mark.parametrize("load_args", [{"location": "l1"}], indirect=True) - @pytest.mark.parametrize("save_args", [{"location": "l2"}], indirect=True) - def test_invalid_location(self, save_args, load_args): - """Check the error when initializing instance if save_args and load_args - 'location' are different.""" - pattern = r""""load_args\['location'\]" is different from "save_args\['location'\]".""" - with pytest.raises(DatasetError, match=pattern): - GBQTableDataSet( - dataset=DATASET, - table_name=TABLE_NAME, - project=PROJECT, - credentials=None, - load_args=load_args, - save_args=save_args, - ) - - @pytest.mark.parametrize("save_args", [{"option1": "value1"}], indirect=True) - @pytest.mark.parametrize("load_args", [{"option2": "value2"}], indirect=True) - def test_str_representation(self, gbq_dataset, save_args, load_args): - """Test string representation of the data set instance.""" - str_repr = str(gbq_dataset) - assert "GBQTableDataSet" in str_repr - assert TABLE_NAME in str_repr - assert DATASET in str_repr - for k in save_args.keys(): - assert k in str_repr - for k in load_args.keys(): - assert k in str_repr - - def test_save_load_data(self, gbq_dataset, dummy_dataframe, mocker): - """Test saving and reloading the data set.""" - sql = f"select * from {DATASET}.{TABLE_NAME}" - table_id = f"{DATASET}.{TABLE_NAME}" - mocked_read_gbq = mocker.patch( - "kedro.extras.datasets.pandas.gbq_dataset.pd.read_gbq" - ) - mocked_read_gbq.return_value = dummy_dataframe - mocked_df = mocker.Mock() - - gbq_dataset.save(mocked_df) - loaded_data = gbq_dataset.load() - - mocked_df.to_gbq.assert_called_once_with( - table_id, project_id=PROJECT, credentials=None, progress_bar=False - ) - mocked_read_gbq.assert_called_once_with( - project_id=PROJECT, credentials=None, query=sql - ) - assert_frame_equal(dummy_dataframe, loaded_data) - - @pytest.mark.parametrize("load_args", [{"query": "Select 1"}], indirect=True) - def 
test_read_gbq_with_query(self, gbq_dataset, dummy_dataframe, mocker, load_args): - """Test loading data set with query in the argument.""" - mocked_read_gbq = mocker.patch( - "kedro.extras.datasets.pandas.gbq_dataset.pd.read_gbq" - ) - mocked_read_gbq.return_value = dummy_dataframe - loaded_data = gbq_dataset.load() - - mocked_read_gbq.assert_called_once_with( - project_id=PROJECT, credentials=None, query=load_args["query"] - ) - - assert_frame_equal(dummy_dataframe, loaded_data) - - @pytest.mark.parametrize( - "dataset,table_name", - [ - ("data set", TABLE_NAME), - ("data;set", TABLE_NAME), - (DATASET, "table name"), - (DATASET, "table;name"), - ], - ) - def test_validation_of_dataset_and_table_name(self, dataset, table_name): - pattern = "Neither white-space nor semicolon are allowed.*" - with pytest.raises(DatasetError, match=pattern): - GBQTableDataSet(dataset=dataset, table_name=table_name) - - def test_credentials_propagation(self, mocker): - credentials = {"token": "my_token"} - credentials_obj = "credentials" - mocked_credentials = mocker.patch( - "kedro.extras.datasets.pandas.gbq_dataset.Credentials", - return_value=credentials_obj, - ) - mocked_bigquery = mocker.patch( - "kedro.extras.datasets.pandas.gbq_dataset.bigquery" - ) - - data_set = GBQTableDataSet( - dataset=DATASET, - table_name=TABLE_NAME, - credentials=credentials, - project=PROJECT, - ) - - assert data_set._credentials == credentials_obj - mocked_credentials.assert_called_once_with(**credentials) - mocked_bigquery.Client.assert_called_once_with( - project=PROJECT, credentials=credentials_obj, location=None - ) - - -class TestGBQQueryDataSet: - def test_empty_query_error(self): - """Check the error when instantiating with empty query or file""" - pattern = ( - r"'sql' and 'filepath' arguments cannot both be empty\." - r"Please provide a sql query or path to a sql query file\." 
- ) - with pytest.raises(DatasetError, match=pattern): - GBQQueryDataSet(sql="", filepath="", credentials=None) - - @pytest.mark.parametrize( - "load_args", [{"k1": "v1", "index": "value"}], indirect=True - ) - def test_load_extra_params(self, gbq_sql_dataset, load_args): - """Test overriding the default load arguments.""" - for key, value in load_args.items(): - assert gbq_sql_dataset._load_args[key] == value - - def test_credentials_propagation(self, mocker): - credentials = {"token": "my_token"} - credentials_obj = "credentials" - mocked_credentials = mocker.patch( - "kedro.extras.datasets.pandas.gbq_dataset.Credentials", - return_value=credentials_obj, - ) - mocked_bigquery = mocker.patch( - "kedro.extras.datasets.pandas.gbq_dataset.bigquery" - ) - - data_set = GBQQueryDataSet( - sql=SQL_QUERY, - credentials=credentials, - project=PROJECT, - ) - - assert data_set._credentials == credentials_obj - mocked_credentials.assert_called_once_with(**credentials) - mocked_bigquery.Client.assert_called_once_with( - project=PROJECT, credentials=credentials_obj, location=None - ) - - def test_load(self, mocker, gbq_sql_dataset, dummy_dataframe): - """Test `load` method invocation""" - mocked_read_gbq = mocker.patch( - "kedro.extras.datasets.pandas.gbq_dataset.pd.read_gbq" - ) - mocked_read_gbq.return_value = dummy_dataframe - - loaded_data = gbq_sql_dataset.load() - - mocked_read_gbq.assert_called_once_with( - project_id=PROJECT, credentials=None, query=SQL_QUERY - ) - - assert_frame_equal(dummy_dataframe, loaded_data) - - def test_load_query_file(self, mocker, gbq_sql_file_dataset, dummy_dataframe): - """Test `load` method invocation using a file as input query""" - mocked_read_gbq = mocker.patch( - "kedro.extras.datasets.pandas.gbq_dataset.pd.read_gbq" - ) - mocked_read_gbq.return_value = dummy_dataframe - - loaded_data = gbq_sql_file_dataset.load() - - mocked_read_gbq.assert_called_once_with( - project_id=PROJECT, credentials=None, query=SQL_QUERY - ) - - assert_frame_equal(dummy_dataframe, loaded_data) - - def test_save_error(self, gbq_sql_dataset, dummy_dataframe): - """Check the error when trying to save to the data set""" - pattern = r"'save' is not supported on GBQQueryDataSet" - with pytest.raises(DatasetError, match=pattern): - gbq_sql_dataset.save(dummy_dataframe) - - def test_str_representation_sql(self, gbq_sql_dataset, sql_file): - """Test the data set instance string representation""" - str_repr = str(gbq_sql_dataset) - assert ( - f"GBQQueryDataSet(filepath=None, load_args={{}}, sql={SQL_QUERY})" - in str_repr - ) - assert sql_file not in str_repr - - def test_str_representation_filepath(self, gbq_sql_file_dataset, sql_file): - """Test the data set instance string representation with filepath arg.""" - str_repr = str(gbq_sql_file_dataset) - assert ( - f"GBQQueryDataSet(filepath={str(sql_file)}, load_args={{}}, sql=None)" - in str_repr - ) - assert SQL_QUERY not in str_repr - - def test_sql_and_filepath_args(self, sql_file): - """Test that an error is raised when both `sql` and `filepath` args are given.""" - pattern = ( - r"'sql' and 'filepath' arguments cannot both be provided." - r"Please only provide one." 
- ) - with pytest.raises(DatasetError, match=pattern): - GBQQueryDataSet(sql=SQL_QUERY, filepath=sql_file) diff --git a/tests/extras/datasets/pandas/test_generic_dataset.py b/tests/extras/datasets/pandas/test_generic_dataset.py deleted file mode 100644 index 23feb861e8..0000000000 --- a/tests/extras/datasets/pandas/test_generic_dataset.py +++ /dev/null @@ -1,383 +0,0 @@ -from pathlib import Path, PurePosixPath -from time import sleep - -import pandas as pd -import pytest -from adlfs import AzureBlobFileSystem -from fsspec.implementations.http import HTTPFileSystem -from fsspec.implementations.local import LocalFileSystem -from gcsfs import GCSFileSystem -from pandas._testing import assert_frame_equal -from s3fs import S3FileSystem - -from kedro.extras.datasets.pandas import GenericDataSet -from kedro.io import DatasetError, Version -from kedro.io.core import PROTOCOL_DELIMITER, generate_timestamp - - -@pytest.fixture -def filepath_sas(tmp_path): - return tmp_path / "test.sas7bdat" - - -@pytest.fixture -def filepath_csv(tmp_path): - return tmp_path / "test.csv" - - -@pytest.fixture -def filepath_html(tmp_path): - return tmp_path / "test.html" - - -# pylint: disable = line-too-long -@pytest.fixture() -def sas_binary(): - return b'\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xc2\xea\x81`\xb3\x14\x11\xcf\xbd\x92\x08\x00\t\xc71\x8c\x18\x1f\x10\x11""\x002"\x01\x022\x04\x00\x00\x00\x00\x00\x00\x00\x00\x00\x03\x01\x18\x1f\x10\x11""\x002"\x01\x022\x042\x01""\x00\x00\x00\x00\x10\x03\x01\x00\x00\x00\x00\x00\x00\x00\x00SAS FILEAIRLINE DATA \x00\x00\xc0\x95j\xbe\xd6A\x00\x00\xc0\x95j\xbe\xd6A\x00\x00\x00\x00\x00 \xbc@\x00\x00\x00\x00\x00 \xbc@\x00\x04\x00\x00\x00\x10\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x009.0000M0WIN\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00WIN\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xc0\x95LN\xaf\xf0LN\xaf\xf0LN\xaf\xf0\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00jIW-\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\
x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00kIW-\x00\x00\x00\x00\x00\x00\x00\x00<\x04\x00\x00\x00\x02-\x00\r\x00\x00\x00 \x0e\x00\x00\xe0\x01\x00\x00\x00\x00\x00\x00\x14\x0e\x00\x00\x0c\x00\x00\x00\x00\x00\x00\x00\xe4\x0c\x00\x000\x01\x00\x00\x00\x00\x00\x00H\x0c\x00\x00\x9c\x00\x00\x00\x00\x01\x00\x00\x04\x0c\x00\x00D\x00\x00\x00\x00\x01\x00\x00\xa8\x0b\x00\x00\\\x00\x00\x00\x00\x01\x00\x00t\x0b\x00\x004\x00\x00\x00\x00\x00\x00\x00@\x0b\x00\x004\x00\x00\x00\x00\x00\x00\x00\x0c\x0b\x00\x004\x00\x00\x00\x00\x00\x00\x00\xd8\n\x00\x004\x00\x00\x00\x00\x00\x00\x00\xa4\n\x00\x004\x00\x00\x00\x00\x00\x00\x00p\n\x00\x004\x00\x00\x00\x00\x00\x00\x00p\n\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x00p\x9e@\x00\x00\x00@\x8bl\xf3?\x00\x00\x00\xc0\x9f\x1a\xcf?\x00\x00\x00\xa0w\x9c\xc2?\x00\x00\x00\x00\xd7\xa3\xf6?\x00\x00\x00\x00\x81\x95\xe3?\x00t\x9e@\x00\x00\x00\xe0\xfb\xa9\xf5?\x00\x00\x00\x00\xd7\xa3\xd0?\x00\x00\x00`\xb3\xea\xcb?\x00\x00\x00 \xdd$\xf6?\x00\x00\x00\x00T\xe3\xe1?\x00x\x9e@\x00\x00\x00\xc0\x9f\x1a\xf9?\x00\x00\x00\x80\xc0\xca\xd1?\x00\x00\x00\xc0m4\xd4?\x00\x00\x00\x80?5\xf6?\x00\x00\x00 \x04V\xe2?\x00|\x9e@\x00\x00\x00\x00\x02+\xff?\x00\x00\x00@\x0c\x02\xd3?\x00\x00\x00\xc0K7\xd9?\x00\x00\x00\xc0\xcc\xcc\xf8?\x00\x00\x00\xc0I\x0c\xe2?\x00\x80\x9e@\x00\x00\x00`\xb8\x1e\x02@\x00\x00\x00@\n\xd7\xd3?\x00\x00\x00\xc0\x10\xc7\xd6?\x00\x00\x00\x00\xfe\xd4\xfc?\x00\x00\x00@5^\xe2?\x00\x84\x9e@\x00\x00\x00\x80\x16\xd9\x05@\x00\x00\x00\xe0\xa5\x9b\xd4?\x00\x00\x00`\xc5\xfe\xd6?\x00\x00\x00`\xe5\xd0\xfe?\x00\x00\x00 \x83\xc0\xe6?\x00\x88\x9e@\x00\x00\x00@33\x08@\x00\x00\x00\xe0\xa3p\xd5?\x00\x00\x00`\x8f\xc2\xd9?\x00\x00\x00@\x8bl\xff?\x00\x00\x00\x00\xfe\xd4\xe8?\x00\x8c\x9e@\x00\x00\x00\xe0\xf9~\x0c@\x00\x00\x00`ff\xd6?\x00\x00\x00\xe0\xb3Y\xd9?\x00\x00\x00`\x91\xed\x00@\x00\x00\x00\xc0\xc8v\xea?\x00\x90\x9e@\x00\x00\x00\x00\xfe\xd4\x0f@\x00\x00\x00\xc0\x9f\x1a\xd7?\x00\x00\x00\x00\xf7u\xd8?\x00\x00\x00@\xe1z\x03@\x00\x00\x00\xa0\x99\x99\xe9?\x00\x94\x9e@\x00\x00\x00\x80\x14\xae\x11@\x00\x00\x00@\x89A\xd8?\x00\x00\x00\xa0\xed|\xd3?\x00\x00\x00\xa0\xef\xa7\x05@\x00\x00\x00\x00\xd5x\xed?\x00\x98\x9e@\x00\x00\x00 
\x83@\x12@\x00\x00\x00\xe0$\x06\xd9?\x00\x00\x00`\x81\x04\xd5?\x00\x00\x00`\xe3\xa5\x05@\x00\x00\x00\xa0n\x12\xf1?\x00\x9c\x9e@\x00\x00\x00\x80=\x8a\x15@\x00\x00\x00\x80\x95C\xdb?\x00\x00\x00\xa0\xab\xad\xd8?\x00\x00\x00\xa0\x9b\xc4\x06@\x00\x00\x00\xc0\xf7S\xf1?\x00\xa0\x9e@\x00\x00\x00\xc0K7\x16@\x00\x00\x00 X9\xdc?\x00\x00\x00@io\xd4?\x00\x00\x00\xa0E\xb6\x08@\x00\x00\x00\x00-\xb2\xf7?\x00\xa4\x9e@\x00\x00\x00\x00)\xdc\x15@\x00\x00\x00\xe0\xa3p\xdd?\x00\x00\x00@\xa2\xb4\xd3?\x00\x00\x00 \xdb\xf9\x08@\x00\x00\x00\xe0\xa7\xc6\xfb?\x00\xa8\x9e@\x00\x00\x00\xc0\xccL\x17@\x00\x00\x00\x80=\n\xdf?\x00\x00\x00@\x116\xd8?\x00\x00\x00\x00\xd5x\t@\x00\x00\x00`\xe5\xd0\xfe?\x00\xac\x9e@\x00\x00\x00 \x06\x81\x1b@\x00\x00\x00\xe0&1\xe0?\x00\x00\x00 \x83\xc0\xda?\x00\x00\x00\xc0\x9f\x1a\n@\x00\x00\x00\xc0\xf7S\x00@\x00\xb0\x9e@\x00\x00\x00\x80\xc0J\x1f@\x00\x00\x00\xc0K7\xe1?\x00\x00\x00\xa0\x87\x85\xe0?\x00\x00\x00\xa0\xc6K\x0b@\x00\x00\x00@\xb6\xf3\xff?\x00\xb4\x9e@\x00\x00\x00\xa0p="@\x00\x00\x00\xc0I\x0c\xe2?\x00\x00\x00\xa0\x13\xd0\xe2?\x00\x00\x00`\xe7\xfb\x0c@\x00\x00\x00\x00V\x0e\x02@\x00\xb8\x9e@\x00\x00\x00\xe0$\x06%@\x00\x00\x00 \x83\xc0\xe2?\x00\x00\x00\xe0H.\xe1?\x00\x00\x00\xa0\xc6K\x10@\x00\x00\x00\xc0\x9d\xef\x05@\x00\xbc\x9e@\x00\x00\x00\x80=\n*@\x00\x00\x00\x80l\xe7\xe3?\x00\x00\x00@io\xdc?\x00\x00\x00@\n\xd7\x12@\x00\x00\x00`\x12\x83\x0c@\x00\xc0\x9e@\x00\x00\x00\xc0\xa1\x85.@\x00\x00\x00@\xdfO\xe5?\x00\x00\x00\xa0e\x88\xd3?\x00\x00\x00@5\xde\x14@\x00\x00\x00\x80h\x11\x13@\x00\xc4\x9e@\x00\x00\x00\xc0 P0@\x00\x00\x00 Zd\xe7?\x00\x00\x00`\x7f\xd9\xcd?\x00\x00\x00\xe0\xa7F\x16@\x00\x00\x00\xa0C\x0b\x1a@\x00\xc8\x9e@\x00\x00\x00 \x83\x000@\x00\x00\x00@\x8d\x97\xea?\x00\x00\x00\xe06\x1a\xc8?\x00\x00\x00@\xe1\xfa\x15@\x00\x00\x00@\x0c\x82\x1e@\x00\xcc\x9e@\x00\x00\x00 \x83\xc0/@\x00\x00\x00\xc0\xf3\xfd\xec?\x00\x00\x00`\xf7\xe4\xc9?\x00\x00\x00 \x04V\x15@\x00\x00\x00\x80\x93X!@\x00\xd0\x9e@\x00\x00\x00\xe0x\xa90@\x00\x00\x00\x00\x00\x00\xf0?\x00\x00\x00\xa0\xd4\t\xd0?\x00\x00\x00\xa0Ga\x15@\x00\x00\x00\xe0x\xa9 @\x00\xd4\x9e@\x00\x00\x00\x80\x95\x031@\x00\x00\x00@`\xe5\xf0?\x00\x00\x00@@\x13\xd1?\x00\x00\x00`\xe3\xa5\x16@\x00\x00\x00 /\x1d!@\x00\xd8\x9e@\x00\x00\x00\x80\x14N3@\x00\x00\x00\x80\x93\x18\xf2?\x00\x00\x00\xa0\xb2\x0c\xd1?\x00\x00\x00\x00\x7f\xea\x16@\x00\x00\x00\xa0\x18\x04#@\x00\xdc\x9e@\x00\x00\x00\x80\x93\xb82@\x00\x00\x00@\xb6\xf3\xf3?\x00\x00\x00\xc0\xeas\xcd?\x00\x00\x00\x00T\xe3\x16@\x00\x00\x00\x80\xbe\x1f"@\x00\xe0\x9e@\x00\x00\x00\x00\x00@3@\x00\x00\x00\x00\x00\x00\xf6?\x00\x00\x00\xc0\xc1\x17\xd6?\x00\x00\x00\xc0I\x0c\x17@\x00\x00\x00\xe0$\x86 @\x00\xe4\x9e@\x00\x00\x00\xc0\xa1\xa54@\x00\x00\x00`9\xb4\xf8?\x00\x00\x00@\xe8\xd9\xdc?\x00\x00\x00@\x0c\x82\x17@\x00\x00\x00@`\xe5\x1d@\x00\xe8\x9e@\x00\x00\x00 
\xdb\xb96@\x00\x00\x00\xe0|?\xfb?\x00\x00\x00@p\xce\xe2?\x00\x00\x00\x80\x97n\x18@\x00\x00\x00\x00\x7fj\x1c@\x00\xec\x9e@\x00\x00\x00\xc0v\x9e7@\x00\x00\x00\xc0\xc8v\xfc?\x00\x00\x00\x80q\x1b\xe1?\x00\x00\x00\xc0rh\x1b@\x00\x00\x00\xe0\xf9~\x1b@\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x
00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xfe\xfb\xff\xff\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00p\x00\r\x00\x00\x00\x00\x00\x00\x00\xfe\xfb\xff\xff\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00`\x00\x0b\x00\x00\x00\x00\x00\x00\x00\xfe\xfb\xff\xff\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00L\x00\r\x00\x00\x00\x00\x00\x00\x00\xfe\xfb\xff\xff\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00<\x00\t\x00\x00\x00\x00\x00\x00\x00\xfe\xfb\xff\xff\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00(\x00\x0f\x00\x00\x00\x00\x00\x00\x00\xfe\xfb\xff\xff\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00 \x00\x04\x00\x00\x00\x00\x00\x00\x00\xfc\xff\xff\xffP\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x04\x00\x00\x00\x00\x04\x01\x00\x04\x00\x00\x00\x08\x00\x00\x00\x00\x04\x01\x00\x0c\x00\x00\x00\x08\x00\x00\x00\x00\x04\x01\x00\x14\x00\x00\x00\x08\x00\x00\x00\x00\x04\x01\x00\x1c\x00\x00\x00\x08\x00\x00\x00\x00\x04\x01\x00$\x00\x00\x00\x08\x00\x00\x00\x00\x04\x01\x00\x00\x00\x00\x00\x00\x00\x00\x00\xff\xff\xff\xff8\x00\x00\x00\x00\x00\x00\x00\x00\x00\x1c\x00\x04\x00\x00\x00\x00\x00$\x00\x01\x00\x00\x00\x00\x008\x00\x01\x00\x00\x00\x00\x00H\x00\x01\x00\x00\x00\x00\x00\\\x00\x01\x00\x00\x00\x00\x00l\x00\x01\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xfd\xff\xff\xff\x90\x00\x10\x00\x80\x00\x00\x00\x00\x00\x00\x00Written by SAS\x00\x00YEARyearY\x00\x00\x00level of output\x00W\x00\x00\x00wage rate\x00\x00\x00R\x00\x00\x00interest rate\x00\x00\x00L\x00\x00\x00labor input\x00K\x00\x00\x00capital 
input\x00\x00\x00\x01\x00\x00\x00\x10\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xfc\xff\xff0\x00\x00\x00\x04\x00\x00\x00\x07\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x0c\x07\x00\x00\x00\x00\x00\x00\xfc\xff\xff\xff\x01\x00\x00\x00\x06\x00\x00\x00\x01\x00\x00\x00\x06\x00\x00\x00\xfd\xff\xff\xff\x01\x00\x00\x00\x04\x00\x00\x00\x01\x00\x00\x00\x04\x00\x00\x00\xff\xff\xff\xff\x01\x00\x00\x00\x05\x00\x00\x00\x01\x00\x00\x00\x05\x00\x00\x00\xfe\xff\xff\xff\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xfb\xff\xff\xff\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xfa\xff\xff\xff\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xf9\xff\xff\xff\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xf6\xf6\xf6\xf6\x06\x00\x00\x00\x00\x00\x00\x00\xf7\xf7\xf7\xf7\xcd\x00\x00\x00\x0e\x00\x00\x00\x00\x00\x00\x00\x110\x02\x00,\x00\x00\x00 \x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x06\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00.\x00\x00\x00\x00\x10\x00\x00\x00\x00\x00\x00 \x00\x00\x00\xff\xff\xff\xff\xff\xff\xff\xff\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00kIW-\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x02\x00\x00\x00\x01\x00\x00\x00\x0c\x00\x00\x00\x01\x00\x00\x00\x0e\x00\x00\x00\x01\x00\x00\x00-\x00\x00\x00\x01\x00\x00\x00\x07\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x04\x00\x00\x00\x0c\x00\x10\x00\x00\x00\x14\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x0c\x00\x08\x00\x00\x00\x1c\x00\x04\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x04\x00\x01\x00\x01\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\\\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x05\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00' - - -@pytest.fixture -def sas_data_set(filepath_sas, fs_args): - return GenericDataSet( - filepath=filepath_sas.as_posix(), - file_format="sas", - load_args={"format": "sas7bdat"}, - fs_args=fs_args, - ) - - -@pytest.fixture -def html_data_set(filepath_html, fs_args): - return 
GenericDataSet( - filepath=filepath_html.as_posix(), - file_format="html", - fs_args=fs_args, - save_args={"index": False}, - ) - - -@pytest.fixture -def sas_data_set_bad_config(filepath_sas, fs_args): - return GenericDataSet( - filepath=filepath_sas.as_posix(), - file_format="sas", - load_args={}, # SAS reader requires a type param - fs_args=fs_args, - ) - - -@pytest.fixture -def versioned_csv_data_set(filepath_csv, load_version, save_version): - return GenericDataSet( - filepath=filepath_csv.as_posix(), - file_format="csv", - version=Version(load_version, save_version), - save_args={"index": False}, - ) - - -@pytest.fixture -def csv_data_set(filepath_csv): - return GenericDataSet( - filepath=filepath_csv.as_posix(), - file_format="csv", - save_args={"index": False}, - ) - - -@pytest.fixture -def dummy_dataframe(): - return pd.DataFrame({"col1": [1, 2], "col2": [4, 5], "col3": [5, 6]}) - - -class TestGenericSasDataSet: - def test_load(self, sas_binary, sas_data_set, filepath_sas): - filepath_sas.write_bytes(sas_binary) - df = sas_data_set.load() - assert df.shape == (32, 6) - - def test_save_fail(self, sas_data_set, dummy_dataframe): - pattern = ( - "Unable to retrieve 'pandas.DataFrame.to_sas' method, please ensure that your " - "'file_format' parameter has been defined correctly as per the Pandas API " - "https://pandas.pydata.org/docs/reference/api/pandas.DataFrame.html" - ) - with pytest.raises(DatasetError, match=pattern): - sas_data_set.save(dummy_dataframe) - # Pandas does not implement a SAS writer - - def test_bad_load(self, sas_data_set_bad_config, sas_binary, filepath_sas): - # SAS reader requires a format param e.g. sas7bdat - filepath_sas.write_bytes(sas_binary) - pattern = "you must specify a format string" - with pytest.raises(DatasetError, match=pattern): - sas_data_set_bad_config.load() - - @pytest.mark.parametrize( - "filepath,instance_type,credentials", - [ - ("s3://bucket/file.sas7bdat", S3FileSystem, {}), - ("file:///tmp/test.sas7bdat", LocalFileSystem, {}), - ("/tmp/test.sas7bdat", LocalFileSystem, {}), - ("gcs://bucket/file.sas7bdat", GCSFileSystem, {}), - ("https://example.com/file.sas7bdat", HTTPFileSystem, {}), - ( - "abfs://bucket/file.sas7bdat", - AzureBlobFileSystem, - {"account_name": "test", "account_key": "test"}, - ), - ], - ) - def test_protocol_usage(self, filepath, instance_type, credentials): - data_set = GenericDataSet( - filepath=filepath, file_format="sas", credentials=credentials - ) - assert isinstance(data_set._fs, instance_type) - - path = filepath.split(PROTOCOL_DELIMITER, 1)[-1] - - assert str(data_set._filepath) == path - assert isinstance(data_set._filepath, PurePosixPath) - - def test_catalog_release(self, mocker): - fs_mock = mocker.patch("fsspec.filesystem").return_value - filepath = "test.csv" - data_set = GenericDataSet(filepath=filepath, file_format="sas") - assert data_set._version_cache.currsize == 0 # no cache if unversioned - data_set.release() - fs_mock.invalidate_cache.assert_called_once_with(filepath) - assert data_set._version_cache.currsize == 0 - - -class TestGenericCSVDataSetVersioned: - def test_version_str_repr(self, filepath_csv, load_version, save_version): - """Test that version is in string representation of the class instance - when applicable.""" - filepath = filepath_csv.as_posix() - ds = GenericDataSet(filepath=filepath, file_format="csv") - ds_versioned = GenericDataSet( - filepath=filepath, - file_format="csv", - version=Version(load_version, save_version), - ) - assert filepath in str(ds) - assert filepath in 
str(ds_versioned) - ver_str = f"version=Version(load={load_version}, save='{save_version}')" - assert ver_str in str(ds_versioned) - assert "GenericDataSet" in str(ds_versioned) - assert "GenericDataSet" in str(ds) - assert "protocol" in str(ds_versioned) - assert "protocol" in str(ds) - - def test_save_and_load(self, versioned_csv_data_set, dummy_dataframe): - """Test that saved and reloaded data matches the original one for - the versioned data set.""" - versioned_csv_data_set.save(dummy_dataframe) - reloaded_df = versioned_csv_data_set.load() - assert_frame_equal(dummy_dataframe, reloaded_df) - - def test_multiple_loads( - self, versioned_csv_data_set, dummy_dataframe, filepath_csv - ): - """Test that if a new version is created mid-run, by an - external system, it won't be loaded in the current run.""" - versioned_csv_data_set.save(dummy_dataframe) - versioned_csv_data_set.load() - v1 = versioned_csv_data_set.resolve_load_version() - - sleep(0.5) - # force-drop a newer version into the same location - v_new = generate_timestamp() - GenericDataSet( - filepath=filepath_csv.as_posix(), - file_format="csv", - version=Version(v_new, v_new), - ).save(dummy_dataframe) - - versioned_csv_data_set.load() - v2 = versioned_csv_data_set.resolve_load_version() - - assert v2 == v1 # v2 should not be v_new! - ds_new = GenericDataSet( - filepath=filepath_csv.as_posix(), - file_format="csv", - version=Version(None, None), - ) - assert ( - ds_new.resolve_load_version() == v_new - ) # new version is discoverable by a new instance - - def test_multiple_saves(self, dummy_dataframe, filepath_csv): - """Test multiple cycles of save followed by load for the same dataset""" - ds_versioned = GenericDataSet( - filepath=filepath_csv.as_posix(), - file_format="csv", - version=Version(None, None), - ) - - # first save - ds_versioned.save(dummy_dataframe) - first_save_version = ds_versioned.resolve_save_version() - first_load_version = ds_versioned.resolve_load_version() - assert first_load_version == first_save_version - - # second save - sleep(0.5) - ds_versioned.save(dummy_dataframe) - second_save_version = ds_versioned.resolve_save_version() - second_load_version = ds_versioned.resolve_load_version() - assert second_load_version == second_save_version - assert second_load_version > first_load_version - - # another dataset - ds_new = GenericDataSet( - filepath=filepath_csv.as_posix(), - file_format="csv", - version=Version(None, None), - ) - assert ds_new.resolve_load_version() == second_load_version - - def test_release_instance_cache(self, dummy_dataframe, filepath_csv): - """Test that cache invalidation does not affect other instances""" - ds_a = GenericDataSet( - filepath=filepath_csv.as_posix(), - file_format="csv", - version=Version(None, None), - ) - assert ds_a._version_cache.currsize == 0 - ds_a.save(dummy_dataframe) # create a version - assert ds_a._version_cache.currsize == 2 - - ds_b = GenericDataSet( - filepath=filepath_csv.as_posix(), - file_format="csv", - version=Version(None, None), - ) - assert ds_b._version_cache.currsize == 0 - ds_b.resolve_save_version() - assert ds_b._version_cache.currsize == 1 - ds_b.resolve_load_version() - assert ds_b._version_cache.currsize == 2 - - ds_a.release() - - # dataset A cache is cleared - assert ds_a._version_cache.currsize == 0 - - # dataset B cache is unaffected - assert ds_b._version_cache.currsize == 2 - - def test_no_versions(self, versioned_csv_data_set): - """Check the error if no versions are available for load.""" - pattern = r"Did not find any 
versions for GenericDataSet\(.+\)" - with pytest.raises(DatasetError, match=pattern): - versioned_csv_data_set.load() - - def test_exists(self, versioned_csv_data_set, dummy_dataframe): - """Test `exists` method invocation for versioned data set.""" - assert not versioned_csv_data_set.exists() - versioned_csv_data_set.save(dummy_dataframe) - assert versioned_csv_data_set.exists() - - def test_prevent_overwrite(self, versioned_csv_data_set, dummy_dataframe): - """Check the error when attempting to override the data set if the - corresponding Generic (csv) file for a given save version already exists.""" - versioned_csv_data_set.save(dummy_dataframe) - pattern = ( - r"Save path \'.+\' for GenericDataSet\(.+\) must " - r"not exist if versioning is enabled\." - ) - with pytest.raises(DatasetError, match=pattern): - versioned_csv_data_set.save(dummy_dataframe) - - @pytest.mark.parametrize( - "load_version", ["2019-01-01T23.59.59.999Z"], indirect=True - ) - @pytest.mark.parametrize( - "save_version", ["2019-01-02T00.00.00.000Z"], indirect=True - ) - def test_save_version_warning( - self, versioned_csv_data_set, load_version, save_version, dummy_dataframe - ): - """Check the warning when saving to the path that differs from - the subsequent load path.""" - pattern = ( - rf"Save version '{save_version}' did not match load version " - rf"'{load_version}' for GenericDataSet\(.+\)" - ) - with pytest.warns(UserWarning, match=pattern): - versioned_csv_data_set.save(dummy_dataframe) - - def test_versioning_existing_dataset( - self, csv_data_set, versioned_csv_data_set, dummy_dataframe - ): - """Check the error when attempting to save a versioned dataset on top of an - already existing (non-versioned) dataset.""" - csv_data_set.save(dummy_dataframe) - assert csv_data_set.exists() - assert csv_data_set._filepath == versioned_csv_data_set._filepath - pattern = ( - f"(?=.*file with the same name already exists in the directory)" - f"(?=.*{versioned_csv_data_set._filepath.parent.as_posix()})" - ) - with pytest.raises(DatasetError, match=pattern): - versioned_csv_data_set.save(dummy_dataframe) - - # Remove non-versioned dataset and try again - Path(csv_data_set._filepath.as_posix()).unlink() - versioned_csv_data_set.save(dummy_dataframe) - assert versioned_csv_data_set.exists() - - -class TestGenericHtmlDataSet: - def test_save_and_load(self, dummy_dataframe, html_data_set): - html_data_set.save(dummy_dataframe) - df = html_data_set.load() - assert_frame_equal(dummy_dataframe, df[0]) - - -class TestBadGenericDataSet: - def test_bad_file_format_argument(self): - ds = GenericDataSet(filepath="test.kedro", file_format="kedro") - - pattern = ( - "Unable to retrieve 'pandas.read_kedro' method, please ensure that your 'file_format' " - "parameter has been defined correctly as per the Pandas API " - "https://pandas.pydata.org/docs/reference/io.html" - ) - - with pytest.raises(DatasetError, match=pattern): - _ = ds.load() - - pattern2 = ( - "Unable to retrieve 'pandas.DataFrame.to_kedro' method, please ensure that your 'file_format' " - "parameter has been defined correctly as per the Pandas API " - "https://pandas.pydata.org/docs/reference/api/pandas.DataFrame.html" - ) - with pytest.raises(DatasetError, match=pattern2): - ds.save(pd.DataFrame([1])) - - @pytest.mark.parametrize( - "file_format", - [ - "clipboard", - "sql_table", - "sql", - "numpy", - "records", - ], - ) - def test_generic_no_filepaths(self, file_format): - error = ( - "Cannot create a dataset of file_format " - f"'{file_format}' as it does not 
support a filepath target/source" - ) - - with pytest.raises(DatasetError, match=error): - _ = GenericDataSet( - filepath="/file/thing.file", file_format=file_format - ).load() - with pytest.raises(DatasetError, match=error): - GenericDataSet(filepath="/file/thing.file", file_format=file_format).save( - pd.DataFrame([1]) - ) diff --git a/tests/extras/datasets/pandas/test_hdf_dataset.py b/tests/extras/datasets/pandas/test_hdf_dataset.py deleted file mode 100644 index 0580e510b4..0000000000 --- a/tests/extras/datasets/pandas/test_hdf_dataset.py +++ /dev/null @@ -1,245 +0,0 @@ -from pathlib import Path, PurePosixPath - -import pandas as pd -import pytest -from fsspec.implementations.http import HTTPFileSystem -from fsspec.implementations.local import LocalFileSystem -from gcsfs import GCSFileSystem -from pandas.testing import assert_frame_equal -from s3fs.core import S3FileSystem - -from kedro.extras.datasets.pandas import HDFDataSet -from kedro.io import DatasetError -from kedro.io.core import PROTOCOL_DELIMITER, Version - -HDF_KEY = "data" - - -@pytest.fixture -def filepath_hdf(tmp_path): - return (tmp_path / "test.h5").as_posix() - - -@pytest.fixture -def hdf_data_set(filepath_hdf, load_args, save_args, mocker, fs_args): - HDFDataSet._lock = mocker.MagicMock() - return HDFDataSet( - filepath=filepath_hdf, - key=HDF_KEY, - load_args=load_args, - save_args=save_args, - fs_args=fs_args, - ) - - -@pytest.fixture -def versioned_hdf_data_set(filepath_hdf, load_version, save_version): - return HDFDataSet( - filepath=filepath_hdf, key=HDF_KEY, version=Version(load_version, save_version) - ) - - -@pytest.fixture -def dummy_dataframe(): - return pd.DataFrame({"col1": [1, 2], "col2": [4, 5], "col3": [5, 6]}) - - -class TestHDFDataSet: - def test_save_and_load(self, hdf_data_set, dummy_dataframe): - """Test saving and reloading the data set.""" - hdf_data_set.save(dummy_dataframe) - reloaded = hdf_data_set.load() - assert_frame_equal(dummy_dataframe, reloaded) - assert hdf_data_set._fs_open_args_load == {} - assert hdf_data_set._fs_open_args_save == {"mode": "wb"} - - def test_exists(self, hdf_data_set, dummy_dataframe): - """Test `exists` method invocation for both existing and - nonexistent data set.""" - assert not hdf_data_set.exists() - hdf_data_set.save(dummy_dataframe) - assert hdf_data_set.exists() - - @pytest.mark.parametrize( - "load_args", [{"k1": "v1", "index": "value"}], indirect=True - ) - def test_load_extra_params(self, hdf_data_set, load_args): - """Test overriding the default load arguments.""" - for key, value in load_args.items(): - assert hdf_data_set._load_args[key] == value - - @pytest.mark.parametrize( - "save_args", [{"k1": "v1", "index": "value"}], indirect=True - ) - def test_save_extra_params(self, hdf_data_set, save_args): - """Test overriding the default save arguments.""" - for key, value in save_args.items(): - assert hdf_data_set._save_args[key] == value - - @pytest.mark.parametrize( - "fs_args", - [{"open_args_load": {"mode": "rb", "compression": "gzip"}}], - indirect=True, - ) - def test_open_extra_args(self, hdf_data_set, fs_args): - assert hdf_data_set._fs_open_args_load == fs_args["open_args_load"] - assert hdf_data_set._fs_open_args_save == {"mode": "wb"} # default unchanged - - def test_load_missing_file(self, hdf_data_set): - """Check the error when trying to load missing file.""" - pattern = r"Failed while loading data from data set HDFDataSet\(.*\)" - with pytest.raises(DatasetError, match=pattern): - hdf_data_set.load() - - @pytest.mark.parametrize( - 
"filepath,instance_type", - [ - ("s3://bucket/file.h5", S3FileSystem), - ("file:///tmp/test.h5", LocalFileSystem), - ("/tmp/test.h5", LocalFileSystem), - ("gcs://bucket/file.h5", GCSFileSystem), - ("https://example.com/file.h5", HTTPFileSystem), - ], - ) - def test_protocol_usage(self, filepath, instance_type): - data_set = HDFDataSet(filepath=filepath, key=HDF_KEY) - assert isinstance(data_set._fs, instance_type) - - path = filepath.split(PROTOCOL_DELIMITER, 1)[-1] - - assert str(data_set._filepath) == path - assert isinstance(data_set._filepath, PurePosixPath) - - def test_catalog_release(self, mocker): - fs_mock = mocker.patch("fsspec.filesystem").return_value - filepath = "test.h5" - data_set = HDFDataSet(filepath=filepath, key=HDF_KEY) - data_set.release() - fs_mock.invalidate_cache.assert_called_once_with(filepath) - - def test_save_and_load_df_with_categorical_variables(self, hdf_data_set): - """Test saving and reloading the data set with categorical variables.""" - df = pd.DataFrame( - {"A": [1, 2, 3], "B": pd.Series(list("aab")).astype("category")} - ) - hdf_data_set.save(df) - reloaded = hdf_data_set.load() - assert_frame_equal(df, reloaded) - - def test_thread_lock_usage(self, hdf_data_set, dummy_dataframe, mocker): - """Test thread lock usage.""" - # pylint: disable=no-member - mocked_lock = HDFDataSet._lock - mocked_lock.assert_not_called() - - hdf_data_set.save(dummy_dataframe) - calls = [ - mocker.call.__enter__(), # pylint: disable=unnecessary-dunder-call - mocker.call.__exit__(None, None, None), - ] - mocked_lock.assert_has_calls(calls) - - mocked_lock.reset_mock() - hdf_data_set.load() - mocked_lock.assert_has_calls(calls) - - -class TestHDFDataSetVersioned: - def test_version_str_repr(self, load_version, save_version): - """Test that version is in string representation of the class instance - when applicable.""" - filepath = "test.h5" - ds = HDFDataSet(filepath=filepath, key=HDF_KEY) - ds_versioned = HDFDataSet( - filepath=filepath, key=HDF_KEY, version=Version(load_version, save_version) - ) - assert filepath in str(ds) - assert "version" not in str(ds) - - assert filepath in str(ds_versioned) - ver_str = f"version=Version(load={load_version}, save='{save_version}')" - assert ver_str in str(ds_versioned) - assert "HDFDataSet" in str(ds_versioned) - assert "HDFDataSet" in str(ds) - assert "protocol" in str(ds_versioned) - assert "protocol" in str(ds) - assert "key" in str(ds_versioned) - assert "key" in str(ds) - - def test_save_and_load(self, versioned_hdf_data_set, dummy_dataframe): - """Test that saved and reloaded data matches the original one for - the versioned data set.""" - versioned_hdf_data_set.save(dummy_dataframe) - reloaded_df = versioned_hdf_data_set.load() - assert_frame_equal(dummy_dataframe, reloaded_df) - - def test_no_versions(self, versioned_hdf_data_set): - """Check the error if no versions are available for load.""" - pattern = r"Did not find any versions for HDFDataSet\(.+\)" - with pytest.raises(DatasetError, match=pattern): - versioned_hdf_data_set.load() - - def test_exists(self, versioned_hdf_data_set, dummy_dataframe): - """Test `exists` method invocation for versioned data set.""" - assert not versioned_hdf_data_set.exists() - versioned_hdf_data_set.save(dummy_dataframe) - assert versioned_hdf_data_set.exists() - - def test_prevent_overwrite(self, versioned_hdf_data_set, dummy_dataframe): - """Check the error when attempting to override the data set if the - corresponding hdf file for a given save version already exists.""" - 
versioned_hdf_data_set.save(dummy_dataframe) - pattern = ( - r"Save path \'.+\' for HDFDataSet\(.+\) must " - r"not exist if versioning is enabled\." - ) - with pytest.raises(DatasetError, match=pattern): - versioned_hdf_data_set.save(dummy_dataframe) - - @pytest.mark.parametrize( - "load_version", ["2019-01-01T23.59.59.999Z"], indirect=True - ) - @pytest.mark.parametrize( - "save_version", ["2019-01-02T00.00.00.000Z"], indirect=True - ) - def test_save_version_warning( - self, versioned_hdf_data_set, load_version, save_version, dummy_dataframe - ): - """Check the warning when saving to the path that differs from - the subsequent load path.""" - pattern = ( - rf"Save version '{save_version}' did not match load version " - rf"'{load_version}' for HDFDataSet\(.+\)" - ) - with pytest.warns(UserWarning, match=pattern): - versioned_hdf_data_set.save(dummy_dataframe) - - def test_http_filesystem_no_versioning(self): - pattern = "Versioning is not supported for HTTP protocols." - - with pytest.raises(DatasetError, match=pattern): - HDFDataSet( - filepath="https://example.com/file.h5", - key=HDF_KEY, - version=Version(None, None), - ) - - def test_versioning_existing_dataset( - self, hdf_data_set, versioned_hdf_data_set, dummy_dataframe - ): - """Check the error when attempting to save a versioned dataset on top of an - already existing (non-versioned) dataset.""" - hdf_data_set.save(dummy_dataframe) - assert hdf_data_set.exists() - assert hdf_data_set._filepath == versioned_hdf_data_set._filepath - pattern = ( - f"(?=.*file with the same name already exists in the directory)" - f"(?=.*{versioned_hdf_data_set._filepath.parent.as_posix()})" - ) - with pytest.raises(DatasetError, match=pattern): - versioned_hdf_data_set.save(dummy_dataframe) - - # Remove non-versioned dataset and try again - Path(hdf_data_set._filepath.as_posix()).unlink() - versioned_hdf_data_set.save(dummy_dataframe) - assert versioned_hdf_data_set.exists() diff --git a/tests/extras/datasets/pandas/test_json_dataset.py b/tests/extras/datasets/pandas/test_json_dataset.py deleted file mode 100644 index fe5c7f8c42..0000000000 --- a/tests/extras/datasets/pandas/test_json_dataset.py +++ /dev/null @@ -1,241 +0,0 @@ -from pathlib import Path, PurePosixPath - -import pandas as pd -import pytest -from adlfs import AzureBlobFileSystem -from fsspec.implementations.http import HTTPFileSystem -from fsspec.implementations.local import LocalFileSystem -from gcsfs import GCSFileSystem -from pandas.testing import assert_frame_equal -from s3fs.core import S3FileSystem - -from kedro.extras.datasets.pandas import JSONDataSet -from kedro.io import DatasetError -from kedro.io.core import PROTOCOL_DELIMITER, Version - - -@pytest.fixture -def filepath_json(tmp_path): - return (tmp_path / "test.json").as_posix() - - -@pytest.fixture -def json_data_set(filepath_json, load_args, save_args, fs_args): - return JSONDataSet( - filepath=filepath_json, - load_args=load_args, - save_args=save_args, - fs_args=fs_args, - ) - - -@pytest.fixture -def versioned_json_data_set(filepath_json, load_version, save_version): - return JSONDataSet( - filepath=filepath_json, version=Version(load_version, save_version) - ) - - -@pytest.fixture -def dummy_dataframe(): - return pd.DataFrame({"col1": [1, 2], "col2": [4, 5], "col3": [5, 6]}) - - -class TestJSONDataSet: - def test_save_and_load(self, json_data_set, dummy_dataframe): - """Test saving and reloading the data set.""" - json_data_set.save(dummy_dataframe) - reloaded = json_data_set.load() - 
assert_frame_equal(dummy_dataframe, reloaded) - - def test_exists(self, json_data_set, dummy_dataframe): - """Test `exists` method invocation for both existing and - nonexistent data set.""" - assert not json_data_set.exists() - json_data_set.save(dummy_dataframe) - assert json_data_set.exists() - - @pytest.mark.parametrize( - "load_args", [{"k1": "v1", "index": "value"}], indirect=True - ) - def test_load_extra_params(self, json_data_set, load_args): - """Test overriding the default load arguments.""" - for key, value in load_args.items(): - assert json_data_set._load_args[key] == value - - @pytest.mark.parametrize( - "save_args", [{"k1": "v1", "index": "value"}], indirect=True - ) - def test_save_extra_params(self, json_data_set, save_args): - """Test overriding the default save arguments.""" - for key, value in save_args.items(): - assert json_data_set._save_args[key] == value - - @pytest.mark.parametrize( - "load_args,save_args", - [ - ({"storage_options": {"a": "b"}}, {}), - ({}, {"storage_options": {"a": "b"}}), - ({"storage_options": {"a": "b"}}, {"storage_options": {"x": "y"}}), - ], - ) - def test_storage_options_dropped(self, load_args, save_args, caplog, tmp_path): - filepath = str(tmp_path / "test.csv") - - ds = JSONDataSet(filepath=filepath, load_args=load_args, save_args=save_args) - - records = [r for r in caplog.records if r.levelname == "WARNING"] - expected_log_message = ( - f"Dropping 'storage_options' for {filepath}, " - f"please specify them under 'fs_args' or 'credentials'." - ) - assert records[0].getMessage() == expected_log_message - assert "storage_options" not in ds._save_args - assert "storage_options" not in ds._load_args - - def test_load_missing_file(self, json_data_set): - """Check the error when trying to load missing file.""" - pattern = r"Failed while loading data from data set JSONDataSet\(.*\)" - with pytest.raises(DatasetError, match=pattern): - json_data_set.load() - - @pytest.mark.parametrize( - "filepath,instance_type,credentials,load_path", - [ - ("s3://bucket/file.json", S3FileSystem, {}, "s3://bucket/file.json"), - ("file:///tmp/test.json", LocalFileSystem, {}, "/tmp/test.json"), - ("/tmp/test.json", LocalFileSystem, {}, "/tmp/test.json"), - ("gcs://bucket/file.json", GCSFileSystem, {}, "gcs://bucket/file.json"), - ( - "https://example.com/file.json", - HTTPFileSystem, - {}, - "https://example.com/file.json", - ), - ( - "abfs://bucket/file.csv", - AzureBlobFileSystem, - {"account_name": "test", "account_key": "test"}, - "abfs://bucket/file.csv", - ), - ], - ) - def test_protocol_usage( - self, filepath, instance_type, credentials, load_path, mocker - ): - data_set = JSONDataSet(filepath=filepath, credentials=credentials) - assert isinstance(data_set._fs, instance_type) - - path = filepath.split(PROTOCOL_DELIMITER, 1)[-1] - - assert str(data_set._filepath) == path - assert isinstance(data_set._filepath, PurePosixPath) - - mock_pandas_call = mocker.patch("pandas.read_json") - data_set.load() - assert mock_pandas_call.call_count == 1 - assert mock_pandas_call.call_args_list[0][0][0] == load_path - - def test_catalog_release(self, mocker): - fs_mock = mocker.patch("fsspec.filesystem").return_value - filepath = "test.json" - data_set = JSONDataSet(filepath=filepath) - data_set.release() - fs_mock.invalidate_cache.assert_called_once_with(filepath) - - -class TestJSONDataSetVersioned: - def test_version_str_repr(self, load_version, save_version): - """Test that version is in string representation of the class instance - when applicable.""" - filepath = 
"test.json" - ds = JSONDataSet(filepath=filepath) - ds_versioned = JSONDataSet( - filepath=filepath, version=Version(load_version, save_version) - ) - assert filepath in str(ds) - assert "version" not in str(ds) - - assert filepath in str(ds_versioned) - ver_str = f"version=Version(load={load_version}, save='{save_version}')" - assert ver_str in str(ds_versioned) - assert "JSONDataSet" in str(ds_versioned) - assert "JSONDataSet" in str(ds) - assert "protocol" in str(ds_versioned) - assert "protocol" in str(ds) - - def test_save_and_load(self, versioned_json_data_set, dummy_dataframe): - """Test that saved and reloaded data matches the original one for - the versioned data set.""" - versioned_json_data_set.save(dummy_dataframe) - reloaded_df = versioned_json_data_set.load() - assert_frame_equal(dummy_dataframe, reloaded_df) - - def test_no_versions(self, versioned_json_data_set): - """Check the error if no versions are available for load.""" - pattern = r"Did not find any versions for JSONDataSet\(.+\)" - with pytest.raises(DatasetError, match=pattern): - versioned_json_data_set.load() - - def test_exists(self, versioned_json_data_set, dummy_dataframe): - """Test `exists` method invocation for versioned data set.""" - assert not versioned_json_data_set.exists() - versioned_json_data_set.save(dummy_dataframe) - assert versioned_json_data_set.exists() - - def test_prevent_overwrite(self, versioned_json_data_set, dummy_dataframe): - """Check the error when attempting to override the data set if the - corresponding hdf file for a given save version already exists.""" - versioned_json_data_set.save(dummy_dataframe) - pattern = ( - r"Save path \'.+\' for JSONDataSet\(.+\) must " - r"not exist if versioning is enabled\." - ) - with pytest.raises(DatasetError, match=pattern): - versioned_json_data_set.save(dummy_dataframe) - - @pytest.mark.parametrize( - "load_version", ["2019-01-01T23.59.59.999Z"], indirect=True - ) - @pytest.mark.parametrize( - "save_version", ["2019-01-02T00.00.00.000Z"], indirect=True - ) - def test_save_version_warning( - self, versioned_json_data_set, load_version, save_version, dummy_dataframe - ): - """Check the warning when saving to the path that differs from - the subsequent load path.""" - pattern = ( - rf"Save version '{save_version}' did not match load version " - rf"'{load_version}' for JSONDataSet\(.+\)" - ) - with pytest.warns(UserWarning, match=pattern): - versioned_json_data_set.save(dummy_dataframe) - - def test_http_filesystem_no_versioning(self): - pattern = "Versioning is not supported for HTTP protocols." 
- - with pytest.raises(DatasetError, match=pattern): - JSONDataSet( - filepath="https://example.com/file.json", version=Version(None, None) - ) - - def test_versioning_existing_dataset( - self, json_data_set, versioned_json_data_set, dummy_dataframe - ): - """Check the error when attempting to save a versioned dataset on top of an - already existing (non-versioned) dataset.""" - json_data_set.save(dummy_dataframe) - assert json_data_set.exists() - assert json_data_set._filepath == versioned_json_data_set._filepath - pattern = ( - f"(?=.*file with the same name already exists in the directory)" - f"(?=.*{versioned_json_data_set._filepath.parent.as_posix()})" - ) - with pytest.raises(DatasetError, match=pattern): - versioned_json_data_set.save(dummy_dataframe) - - # Remove non-versioned dataset and try again - Path(json_data_set._filepath.as_posix()).unlink() - versioned_json_data_set.save(dummy_dataframe) - assert versioned_json_data_set.exists() diff --git a/tests/extras/datasets/pandas/test_parquet_dataset.py b/tests/extras/datasets/pandas/test_parquet_dataset.py deleted file mode 100644 index 5e415bd75b..0000000000 --- a/tests/extras/datasets/pandas/test_parquet_dataset.py +++ /dev/null @@ -1,344 +0,0 @@ -from pathlib import Path, PurePosixPath - -import pandas as pd -import pyarrow.parquet as pq -import pytest -from fsspec.implementations.http import HTTPFileSystem -from fsspec.implementations.local import LocalFileSystem -from gcsfs import GCSFileSystem -from pandas.testing import assert_frame_equal -from pyarrow.fs import FSSpecHandler, PyFileSystem -from s3fs.core import S3FileSystem - -from kedro.extras.datasets.pandas import ParquetDataSet -from kedro.io import DatasetError -from kedro.io.core import PROTOCOL_DELIMITER, Version - -FILENAME = "test.parquet" - - -@pytest.fixture -def filepath_parquet(tmp_path): - return (tmp_path / FILENAME).as_posix() - - -@pytest.fixture -def parquet_data_set(filepath_parquet, load_args, save_args, fs_args): - return ParquetDataSet( - filepath=filepath_parquet, - load_args=load_args, - save_args=save_args, - fs_args=fs_args, - ) - - -@pytest.fixture -def versioned_parquet_data_set(filepath_parquet, load_version, save_version): - return ParquetDataSet( - filepath=filepath_parquet, version=Version(load_version, save_version) - ) - - -@pytest.fixture -def dummy_dataframe(): - return pd.DataFrame({"col1": [1, 2], "col2": [4, 5], "col3": [5, 6]}) - - -class TestParquetDataSet: - def test_credentials_propagated(self, mocker): - """Test propagating credentials for connecting to GCS""" - mock_fs = mocker.patch("fsspec.filesystem") - credentials = {"key": "value"} - - ParquetDataSet(filepath=FILENAME, credentials=credentials) - - mock_fs.assert_called_once_with("file", auto_mkdir=True, **credentials) - - def test_save_and_load(self, tmp_path, dummy_dataframe): - """Test saving and reloading the data set.""" - filepath = (tmp_path / FILENAME).as_posix() - data_set = ParquetDataSet(filepath=filepath) - data_set.save(dummy_dataframe) - reloaded = data_set.load() - assert_frame_equal(dummy_dataframe, reloaded) - - files = [child.is_file() for child in tmp_path.iterdir()] - assert all(files) - assert len(files) == 1 - - def test_save_and_load_non_existing_dir(self, tmp_path, dummy_dataframe): - """Test saving and reloading the data set to non-existing directory.""" - filepath = (tmp_path / "non-existing" / FILENAME).as_posix() - data_set = ParquetDataSet(filepath=filepath) - data_set.save(dummy_dataframe) - reloaded = data_set.load() - 
assert_frame_equal(dummy_dataframe, reloaded)
-
-    def test_exists(self, parquet_data_set, dummy_dataframe):
-        """Test `exists` method invocation for both existing and
-        nonexistent data set."""
-        assert not parquet_data_set.exists()
-        parquet_data_set.save(dummy_dataframe)
-        assert parquet_data_set.exists()
-
-    @pytest.mark.parametrize(
-        "load_args", [{"k1": "v1", "index": "value"}], indirect=True
-    )
-    def test_load_extra_params(self, parquet_data_set, load_args):
-        """Test overriding the default load arguments."""
-        for key, value in load_args.items():
-            assert parquet_data_set._load_args[key] == value
-
-    @pytest.mark.parametrize(
-        "save_args", [{"k1": "v1", "index": "value"}], indirect=True
-    )
-    def test_save_extra_params(self, parquet_data_set, save_args):
-        """Test overriding the default save arguments."""
-        for key, value in save_args.items():
-            assert parquet_data_set._save_args[key] == value
-
-    @pytest.mark.parametrize(
-        "load_args,save_args",
-        [
-            ({"storage_options": {"a": "b"}}, {}),
-            ({}, {"storage_options": {"a": "b"}}),
-            ({"storage_options": {"a": "b"}}, {"storage_options": {"x": "y"}}),
-        ],
-    )
-    def test_storage_options_dropped(self, load_args, save_args, caplog, tmp_path):
-        filepath = str(tmp_path / "test.csv")
-
-        ds = ParquetDataSet(filepath=filepath, load_args=load_args, save_args=save_args)
-
-        records = [r for r in caplog.records if r.levelname == "WARNING"]
-        expected_log_message = (
-            f"Dropping 'storage_options' for {filepath}, "
-            f"please specify them under 'fs_args' or 'credentials'."
-        )
-        assert records[0].getMessage() == expected_log_message
-        assert "storage_options" not in ds._save_args
-        assert "storage_options" not in ds._load_args
-
-    def test_load_missing_file(self, parquet_data_set):
-        """Check the error when trying to load missing file."""
-        pattern = r"Failed while loading data from data set ParquetDataSet\(.*\)"
-        with pytest.raises(DatasetError, match=pattern):
-            parquet_data_set.load()
-
-    @pytest.mark.parametrize(
-        "filepath,instance_type,load_path",
-        [
-            ("s3://bucket/file.parquet", S3FileSystem, "s3://bucket/file.parquet"),
-            ("file:///tmp/test.parquet", LocalFileSystem, "/tmp/test.parquet"),
-            ("/tmp/test.parquet", LocalFileSystem, "/tmp/test.parquet"),
-            ("gcs://bucket/file.parquet", GCSFileSystem, "gcs://bucket/file.parquet"),
-            (
-                "https://example.com/file.parquet",
-                HTTPFileSystem,
-                "https://example.com/file.parquet",
-            ),
-        ],
-    )
-    def test_protocol_usage(self, filepath, instance_type, load_path, mocker):
-        data_set = ParquetDataSet(filepath=filepath)
-        assert isinstance(data_set._fs, instance_type)
-
-        path = filepath.split(PROTOCOL_DELIMITER, 1)[-1]
-
-        assert str(data_set._filepath) == path
-        assert isinstance(data_set._filepath, PurePosixPath)
-
-        mocker.patch.object(data_set._fs, "isdir", return_value=False)
-        mock_pandas_call = mocker.patch("pandas.read_parquet")
-        data_set.load()
-        assert mock_pandas_call.call_count == 1
-        assert mock_pandas_call.call_args_list[0][0][0] == load_path
-
-    @pytest.mark.parametrize(
-        "protocol,path", [("https://", "example.com/"), ("s3://", "bucket/")]
-    )
-    def test_catalog_release(self, protocol, path, mocker):
-        filepath = protocol + path + FILENAME
-        fs_mock = mocker.patch("fsspec.filesystem").return_value
-        data_set = ParquetDataSet(filepath=filepath)
-        data_set.release()
-        if protocol != "https://":
-            filepath = path + FILENAME
-        fs_mock.invalidate_cache.assert_called_once_with(filepath)
-
-    def test_read_partitioned_file(self, mocker, tmp_path, dummy_dataframe):
-        """Test read partitioned parquet file from local directory."""
-        pq_ds_mock = mocker.patch(
-            "pyarrow.parquet.ParquetDataset", wraps=pq.ParquetDataset
-        )
-        dummy_dataframe.to_parquet(str(tmp_path), partition_cols=["col2"])
-        data_set = ParquetDataSet(filepath=tmp_path.as_posix())
-
-        reloaded = data_set.load()
-        # Sort by columns because reading partitioned file results
-        # in different columns order
-        reloaded = reloaded.sort_index(axis=1)
-        # dtype for partition column is 'category'
-        assert_frame_equal(
-            dummy_dataframe, reloaded, check_dtype=False, check_categorical=False
-        )
-        pq_ds_mock.assert_called_once()
-
-    def test_write_to_dir(self, dummy_dataframe, tmp_path):
-        data_set = ParquetDataSet(filepath=tmp_path.as_posix())
-        pattern = "Saving ParquetDataSet to a directory is not supported"
-
-        with pytest.raises(DatasetError, match=pattern):
-            data_set.save(dummy_dataframe)
-
-    def test_read_from_non_local_dir(self, mocker):
-        fs_mock = mocker.patch("fsspec.filesystem").return_value
-        fs_mock.isdir.return_value = True
-        pq_ds_mock = mocker.patch("pyarrow.parquet.ParquetDataset")
-
-        data_set = ParquetDataSet(filepath="s3://bucket/dir")
-
-        data_set.load()
-        fs_mock.isdir.assert_called_once()
-        assert not fs_mock.open.called
-        pq_ds_mock.assert_called_once_with("bucket/dir", filesystem=fs_mock)
-        pq_ds_mock().read().to_pandas.assert_called_once_with()
-
-    def test_read_from_file(self, mocker):
-        fs_mock = mocker.patch("fsspec.filesystem").return_value
-        fs_mock.isdir.return_value = False
-        mocker.patch("pandas.read_parquet")
-
-        data_set = ParquetDataSet(filepath="/tmp/test.parquet")
-
-        data_set.load()
-        fs_mock.isdir.assert_called_once()
-
-    def test_arg_partition_cols(self, dummy_dataframe, tmp_path):
-        data_set = ParquetDataSet(
-            filepath=(tmp_path / FILENAME).as_posix(),
-            save_args={"partition_cols": ["col2"]},
-        )
-        pattern = "does not support save argument 'partition_cols'"
-
-        with pytest.raises(DatasetError, match=pattern):
-            data_set.save(dummy_dataframe)
-
-
-class TestParquetDataSetVersioned:
-    def test_version_str_repr(self, load_version, save_version):
-        """Test that version is in string representation of the class instance
-        when applicable."""
-        ds = ParquetDataSet(filepath=FILENAME)
-        ds_versioned = ParquetDataSet(
-            filepath=FILENAME, version=Version(load_version, save_version)
-        )
-        assert FILENAME in str(ds)
-        assert "version" not in str(ds)
-
-        assert FILENAME in str(ds_versioned)
-        ver_str = f"version=Version(load={load_version}, save='{save_version}')"
-        assert ver_str in str(ds_versioned)
-        assert "ParquetDataSet" in str(ds_versioned)
-        assert "ParquetDataSet" in str(ds)
-        assert "protocol" in str(ds_versioned)
-        assert "protocol" in str(ds)
-
-    def test_save_and_load(self, versioned_parquet_data_set, dummy_dataframe, mocker):
-        """Test that saved and reloaded data matches the original one for
-        the versioned data set."""
-        mocker.patch(
-            "pyarrow.fs._ensure_filesystem",
-            return_value=PyFileSystem(FSSpecHandler(versioned_parquet_data_set._fs)),
-        )
-        versioned_parquet_data_set.save(dummy_dataframe)
-        reloaded_df = versioned_parquet_data_set.load()
-        assert_frame_equal(dummy_dataframe, reloaded_df)
-
-    def test_no_versions(self, versioned_parquet_data_set):
-        """Check the error if no versions are available for load."""
-        pattern = r"Did not find any versions for ParquetDataSet\(.+\)"
-        with pytest.raises(DatasetError, match=pattern):
-            versioned_parquet_data_set.load()
-
-    def test_exists(self, versioned_parquet_data_set, dummy_dataframe, mocker):
-        """Test `exists` method invocation for versioned data set."""
-        assert not versioned_parquet_data_set.exists()
-        mocker.patch(
-            "pyarrow.fs._ensure_filesystem",
-            return_value=PyFileSystem(FSSpecHandler(versioned_parquet_data_set._fs)),
-        )
-        versioned_parquet_data_set.save(dummy_dataframe)
-        assert versioned_parquet_data_set.exists()
-
-    def test_prevent_overwrite(
-        self, versioned_parquet_data_set, dummy_dataframe, mocker
-    ):
-        """Check the error when attempting to override the data set if the
-        corresponding parquet file for a given save version already exists."""
-        mocker.patch(
-            "pyarrow.fs._ensure_filesystem",
-            return_value=PyFileSystem(FSSpecHandler(versioned_parquet_data_set._fs)),
-        )
-        versioned_parquet_data_set.save(dummy_dataframe)
-        pattern = (
-            r"Save path \'.+\' for ParquetDataSet\(.+\) must "
-            r"not exist if versioning is enabled\."
-        )
-        with pytest.raises(DatasetError, match=pattern):
-            versioned_parquet_data_set.save(dummy_dataframe)
-
-    @pytest.mark.parametrize(
-        "load_version", ["2019-01-01T23.59.59.999Z"], indirect=True
-    )
-    @pytest.mark.parametrize(
-        "save_version", ["2019-01-02T00.00.00.000Z"], indirect=True
-    )
-    def test_save_version_warning(
-        self,
-        versioned_parquet_data_set,
-        load_version,
-        save_version,
-        dummy_dataframe,
-        mocker,
-    ):
-        """Check the warning when saving to the path that differs from
-        the subsequent load path."""
-        pattern = (
-            rf"Save version '{save_version}' did not match load version "
-            rf"'{load_version}' for ParquetDataSet\(.+\)"
-        )
-        mocker.patch(
-            "pyarrow.fs._ensure_filesystem",
-            return_value=PyFileSystem(FSSpecHandler(versioned_parquet_data_set._fs)),
-        )
-        with pytest.warns(UserWarning, match=pattern):
-            versioned_parquet_data_set.save(dummy_dataframe)
-
-    def test_http_filesystem_no_versioning(self):
-        pattern = "Versioning is not supported for HTTP protocols."
-
-        with pytest.raises(DatasetError, match=pattern):
-            ParquetDataSet(
-                filepath="https://example.com/test.parquet", version=Version(None, None)
-            )
-
-    def test_versioning_existing_dataset(
-        self, parquet_data_set, versioned_parquet_data_set, dummy_dataframe
-    ):
-        """Check the error when attempting to save a versioned dataset on top of an
-        already existing (non-versioned) dataset."""
-        parquet_data_set.save(dummy_dataframe)
-        assert parquet_data_set.exists()
-        assert parquet_data_set._filepath == versioned_parquet_data_set._filepath
-        pattern = (
-            f"(?=.*file with the same name already exists in the directory)"
-            f"(?=.*{versioned_parquet_data_set._filepath.parent.as_posix()})"
-        )
-        with pytest.raises(DatasetError, match=pattern):
-            versioned_parquet_data_set.save(dummy_dataframe)
-
-        # Remove non-versioned dataset and try again
-        Path(parquet_data_set._filepath.as_posix()).unlink()
-        versioned_parquet_data_set.save(dummy_dataframe)
-        assert versioned_parquet_data_set.exists()
diff --git a/tests/extras/datasets/pandas/test_sql_dataset.py b/tests/extras/datasets/pandas/test_sql_dataset.py
deleted file mode 100644
index d80ee12090..0000000000
--- a/tests/extras/datasets/pandas/test_sql_dataset.py
+++ /dev/null
@@ -1,425 +0,0 @@
-# pylint: disable=no-member
-from pathlib import PosixPath
-from unittest.mock import ANY
-
-import pandas as pd
-import pytest
-import sqlalchemy
-
-from kedro.extras.datasets.pandas import SQLQueryDataSet, SQLTableDataSet
-from kedro.io import DatasetError
-
-TABLE_NAME = "table_a"
-CONNECTION = "sqlite:///kedro.db"
-SQL_QUERY = "SELECT * FROM table_a"
-EXECUTION_OPTIONS = {"stream_results": True}
-FAKE_CONN_STR = "some_sql://scott:tiger@localhost/foo"
-ERROR_PREFIX = (
-    r"A module\/driver is missing when connecting to your SQL server\.(.|\n)*"
-)
-
-
-@pytest.fixture(autouse=True)
-def cleanup_engines():
-    yield
-    SQLTableDataSet.engines = {}
-    SQLQueryDataSet.engines = {}
-
-
-@pytest.fixture
-def dummy_dataframe():
-    return pd.DataFrame({"col1": [1, 2], "col2": [4, 5], "col3": [5, 6]})
-
-
-@pytest.fixture
-def sql_file(tmp_path: PosixPath):
-    file = tmp_path / "test.sql"
-    file.write_text(SQL_QUERY)
-    return file.as_posix()
-
-
-@pytest.fixture(params=[{}])
-def table_data_set(request):
-    kwargs = {"table_name": TABLE_NAME, "credentials": {"con": CONNECTION}}
-    kwargs.update(request.param)
-    return SQLTableDataSet(**kwargs)
-
-
-@pytest.fixture(params=[{}])
-def query_data_set(request):
-    kwargs = {"sql": SQL_QUERY, "credentials": {"con": CONNECTION}}
-    kwargs.update(request.param)
-    return SQLQueryDataSet(**kwargs)
-
-
-@pytest.fixture(params=[{}])
-def query_file_data_set(request, sql_file):
-    kwargs = {"filepath": sql_file, "credentials": {"con": CONNECTION}}
-    kwargs.update(request.param)
-    return SQLQueryDataSet(**kwargs)
-
-
-class TestSQLTableDataSet:
-    _unknown_conn = "mysql+unknown_module://scott:tiger@localhost/foo"
-
-    @staticmethod
-    def _assert_sqlalchemy_called_once(*args):
-        _callable = sqlalchemy.engine.Engine.table_names
-        if args:
-            _callable.assert_called_once_with(*args)
-        else:
-            assert _callable.call_count == 1
-
-    def test_empty_table_name(self):
-        """Check the error when instantiating with an empty table"""
-        pattern = r"'table\_name' argument cannot be empty\."
- with pytest.raises(DatasetError, match=pattern): - SQLTableDataSet(table_name="", credentials={"con": CONNECTION}) - - def test_empty_connection(self): - """Check the error when instantiating with an empty - connection string""" - pattern = ( - r"'con' argument cannot be empty\. " - r"Please provide a SQLAlchemy connection string\." - ) - with pytest.raises(DatasetError, match=pattern): - SQLTableDataSet(table_name=TABLE_NAME, credentials={"con": ""}) - - def test_driver_missing(self, mocker): - """Check the error when the sql driver is missing""" - mocker.patch( - "kedro.extras.datasets.pandas.sql_dataset.create_engine", - side_effect=ImportError("No module named 'mysqldb'"), - ) - with pytest.raises(DatasetError, match=ERROR_PREFIX + "mysqlclient"): - SQLTableDataSet(table_name=TABLE_NAME, credentials={"con": CONNECTION}) - - def test_unknown_sql(self): - """Check the error when unknown sql dialect is provided; - this means the error is raised on catalog creation, rather - than on load or save operation. - """ - pattern = r"The SQL dialect in your connection is not supported by SQLAlchemy" - with pytest.raises(DatasetError, match=pattern): - SQLTableDataSet(table_name=TABLE_NAME, credentials={"con": FAKE_CONN_STR}) - - def test_unknown_module(self, mocker): - """Test that if an unknown module/driver is encountered by SQLAlchemy - then the error should contain the original error message""" - mocker.patch( - "kedro.extras.datasets.pandas.sql_dataset.create_engine", - side_effect=ImportError("No module named 'unknown_module'"), - ) - pattern = ERROR_PREFIX + r"No module named \'unknown\_module\'" - with pytest.raises(DatasetError, match=pattern): - SQLTableDataSet(table_name=TABLE_NAME, credentials={"con": CONNECTION}) - - def test_str_representation_table(self, table_data_set): - """Test the data set instance string representation""" - str_repr = str(table_data_set) - assert ( - "SQLTableDataSet(load_args={}, save_args={'index': False}, " - f"table_name={TABLE_NAME})" in str_repr - ) - assert CONNECTION not in str(str_repr) - - def test_table_exists(self, mocker, table_data_set): - """Test `exists` method invocation""" - mocker.patch("sqlalchemy.engine.Engine.table_names") - assert not table_data_set.exists() - self._assert_sqlalchemy_called_once() - - @pytest.mark.parametrize( - "table_data_set", [{"load_args": {"schema": "ingested"}}], indirect=True - ) - def test_table_exists_schema(self, mocker, table_data_set): - """Test `exists` method invocation with DB schema provided""" - mocker.patch("sqlalchemy.engine.Engine.table_names") - assert not table_data_set.exists() - self._assert_sqlalchemy_called_once("ingested") - - def test_table_exists_mocked(self, mocker, table_data_set): - """Test `exists` method invocation with mocked list of tables""" - mocker.patch("sqlalchemy.engine.Engine.table_names", return_value=[TABLE_NAME]) - assert table_data_set.exists() - self._assert_sqlalchemy_called_once() - - def test_load_sql_params(self, mocker, table_data_set): - """Test `load` method invocation""" - mocker.patch("pandas.read_sql_table") - table_data_set.load() - pd.read_sql_table.assert_called_once_with( - table_name=TABLE_NAME, con=table_data_set.engines[CONNECTION] - ) - - def test_save_default_index(self, mocker, table_data_set, dummy_dataframe): - """Test `save` method invocation""" - mocker.patch.object(dummy_dataframe, "to_sql") - table_data_set.save(dummy_dataframe) - dummy_dataframe.to_sql.assert_called_once_with( - name=TABLE_NAME, con=table_data_set.engines[CONNECTION], 
index=False - ) - - @pytest.mark.parametrize( - "table_data_set", [{"save_args": {"index": True}}], indirect=True - ) - def test_save_overwrite_index(self, mocker, table_data_set, dummy_dataframe): - """Test writing DataFrame index as a column""" - mocker.patch.object(dummy_dataframe, "to_sql") - table_data_set.save(dummy_dataframe) - dummy_dataframe.to_sql.assert_called_once_with( - name=TABLE_NAME, con=table_data_set.engines[CONNECTION], index=True - ) - - @pytest.mark.parametrize( - "table_data_set", [{"save_args": {"name": "TABLE_B"}}], indirect=True - ) - def test_save_ignore_table_name_override( - self, mocker, table_data_set, dummy_dataframe - ): - """Test that putting the table name is `save_args` does not have any - effect""" - mocker.patch.object(dummy_dataframe, "to_sql") - table_data_set.save(dummy_dataframe) - dummy_dataframe.to_sql.assert_called_once_with( - name=TABLE_NAME, con=table_data_set.engines[CONNECTION], index=False - ) - - -class TestSQLTableDataSetSingleConnection: - def test_single_connection(self, dummy_dataframe, mocker): - """Test to make sure multiple instances use the same connection object.""" - mocker.patch("pandas.read_sql_table") - dummy_to_sql = mocker.patch.object(dummy_dataframe, "to_sql") - kwargs = {"table_name": TABLE_NAME, "credentials": {"con": CONNECTION}} - - first = SQLTableDataSet(**kwargs) - unique_connection = first.engines[CONNECTION] - datasets = [SQLTableDataSet(**kwargs) for _ in range(10)] - - for ds in datasets: - ds.save(dummy_dataframe) - engine = ds.engines[CONNECTION] - assert engine is unique_connection - - expected_call = mocker.call(name=TABLE_NAME, con=unique_connection, index=False) - dummy_to_sql.assert_has_calls([expected_call] * 10) - - for ds in datasets: - ds.load() - engine = ds.engines[CONNECTION] - assert engine is unique_connection - - def test_create_connection_only_once(self, mocker): - """Test that two datasets that need to connect to the same db - (but different tables, for example) only create a connection once. - """ - mock_engine = mocker.patch( - "kedro.extras.datasets.pandas.sql_dataset.create_engine" - ) - first = SQLTableDataSet(table_name=TABLE_NAME, credentials={"con": CONNECTION}) - assert len(first.engines) == 1 - - second = SQLTableDataSet( - table_name="other_table", credentials={"con": CONNECTION} - ) - assert len(second.engines) == 1 - assert len(first.engines) == 1 - - mock_engine.assert_called_once_with(CONNECTION) - - def test_multiple_connections(self, mocker): - """Test that two datasets that need to connect to different dbs - only create one connection per db. - """ - mock_engine = mocker.patch( - "kedro.extras.datasets.pandas.sql_dataset.create_engine" - ) - first = SQLTableDataSet(table_name=TABLE_NAME, credentials={"con": CONNECTION}) - assert len(first.engines) == 1 - - second_con = f"other_{CONNECTION}" - second = SQLTableDataSet(table_name=TABLE_NAME, credentials={"con": second_con}) - assert len(second.engines) == 2 - assert len(first.engines) == 2 - - expected_calls = [mocker.call(CONNECTION), mocker.call(second_con)] - assert mock_engine.call_args_list == expected_calls - - -class TestSQLQueryDataSet: - def test_empty_query_error(self): - """Check the error when instantiating with empty query or file""" - pattern = ( - r"'sql' and 'filepath' arguments cannot both be empty\." - r"Please provide a sql query or path to a sql query file\." 
- ) - with pytest.raises(DatasetError, match=pattern): - SQLQueryDataSet(sql="", filepath="", credentials={"con": CONNECTION}) - - def test_empty_con_error(self): - """Check the error when instantiating with empty connection string""" - pattern = ( - r"'con' argument cannot be empty\. Please provide " - r"a SQLAlchemy connection string" - ) - with pytest.raises(DatasetError, match=pattern): - SQLQueryDataSet(sql=SQL_QUERY, credentials={"con": ""}) - - @pytest.mark.parametrize( - "query_data_set, has_execution_options", - [ - ({"execution_options": EXECUTION_OPTIONS}, True), - ({"execution_options": {}}, False), - ({}, False), - ], - indirect=["query_data_set"], - ) - def test_load(self, mocker, query_data_set, has_execution_options): - """Test `load` method invocation""" - mocker.patch("pandas.read_sql_query") - query_data_set.load() - - # Check that data was loaded with the expected query, connection string and - # execution options: - pd.read_sql_query.assert_called_once_with(sql=SQL_QUERY, con=ANY) - con_arg = pd.read_sql_query.call_args_list[0][1]["con"] - assert str(con_arg.url) == CONNECTION - assert len(con_arg.get_execution_options()) == bool(has_execution_options) - if has_execution_options: - assert con_arg.get_execution_options() == EXECUTION_OPTIONS - - @pytest.mark.parametrize( - "query_file_data_set, has_execution_options", - [ - ({"execution_options": EXECUTION_OPTIONS}, True), - ({"execution_options": {}}, False), - ({}, False), - ], - indirect=["query_file_data_set"], - ) - def test_load_query_file(self, mocker, query_file_data_set, has_execution_options): - """Test `load` method with a query file""" - mocker.patch("pandas.read_sql_query") - query_file_data_set.load() - - # Check that data was loaded with the expected query, connection string and - # execution options: - pd.read_sql_query.assert_called_once_with(sql=SQL_QUERY, con=ANY) - con_arg = pd.read_sql_query.call_args_list[0][1]["con"] - assert str(con_arg.url) == CONNECTION - assert len(con_arg.get_execution_options()) == bool(has_execution_options) - if has_execution_options: - assert con_arg.get_execution_options() == EXECUTION_OPTIONS - - def test_load_driver_missing(self, mocker): - """Test that if an unknown module/driver is encountered by SQLAlchemy - then the error should contain the original error message""" - _err = ImportError("No module named 'mysqldb'") - mocker.patch( - "kedro.extras.datasets.pandas.sql_dataset.create_engine", side_effect=_err - ) - with pytest.raises(DatasetError, match=ERROR_PREFIX + "mysqlclient"): - SQLQueryDataSet(sql=SQL_QUERY, credentials={"con": CONNECTION}) - - def test_invalid_module(self, mocker): - """Test that if an unknown module/driver is encountered by SQLAlchemy - then the error should contain the original error message""" - _err = ImportError("Invalid module some_module") - mocker.patch( - "kedro.extras.datasets.pandas.sql_dataset.create_engine", side_effect=_err - ) - pattern = ERROR_PREFIX + r"Invalid module some\_module" - with pytest.raises(DatasetError, match=pattern): - SQLQueryDataSet(sql=SQL_QUERY, credentials={"con": CONNECTION}) - - def test_load_unknown_module(self, mocker): - """Test that if an unknown module/driver is encountered by SQLAlchemy - then the error should contain the original error message""" - _err = ImportError("No module named 'unknown_module'") - mocker.patch( - "kedro.extras.datasets.pandas.sql_dataset.create_engine", side_effect=_err - ) - pattern = ERROR_PREFIX + r"No module named \'unknown\_module\'" - with pytest.raises(DatasetError, 
match=pattern): - SQLQueryDataSet(sql=SQL_QUERY, credentials={"con": CONNECTION}) - - def test_load_unknown_sql(self): - """Check the error when unknown SQL dialect is provided - in the connection string""" - pattern = r"The SQL dialect in your connection is not supported by SQLAlchemy" - with pytest.raises(DatasetError, match=pattern): - SQLQueryDataSet(sql=SQL_QUERY, credentials={"con": FAKE_CONN_STR}) - - def test_save_error(self, query_data_set, dummy_dataframe): - """Check the error when trying to save to the data set""" - pattern = r"'save' is not supported on SQLQueryDataSet" - with pytest.raises(DatasetError, match=pattern): - query_data_set.save(dummy_dataframe) - - def test_str_representation_sql(self, query_data_set, sql_file): - """Test the data set instance string representation""" - str_repr = str(query_data_set) - assert ( - "SQLQueryDataSet(execution_options={}, filepath=None, " - f"load_args={{}}, sql={SQL_QUERY})" in str_repr - ) - assert CONNECTION not in str_repr - assert sql_file not in str_repr - - def test_str_representation_filepath(self, query_file_data_set, sql_file): - """Test the data set instance string representation with filepath arg.""" - str_repr = str(query_file_data_set) - assert ( - f"SQLQueryDataSet(execution_options={{}}, filepath={str(sql_file)}, " - "load_args={}, sql=None)" in str_repr - ) - assert CONNECTION not in str_repr - assert SQL_QUERY not in str_repr - - def test_sql_and_filepath_args(self, sql_file): - """Test that an error is raised when both `sql` and `filepath` args are given.""" - pattern = ( - r"'sql' and 'filepath' arguments cannot both be provided." - r"Please only provide one." - ) - with pytest.raises(DatasetError, match=pattern): - SQLQueryDataSet(sql=SQL_QUERY, filepath=sql_file) - - def test_create_connection_only_once(self, mocker): - """Test that two datasets that need to connect to the same db (but different - tables and execution options, for example) only create a connection once. 
- """ - mock_engine = mocker.patch( - "kedro.extras.datasets.pandas.sql_dataset.create_engine" - ) - first = SQLQueryDataSet(sql=SQL_QUERY, credentials={"con": CONNECTION}) - assert len(first.engines) == 1 - - # second engine has identical params to the first one - # => no new engine should be created - second = SQLQueryDataSet(sql=SQL_QUERY, credentials={"con": CONNECTION}) - mock_engine.assert_called_once_with(CONNECTION) - assert second.engines == first.engines - assert len(first.engines) == 1 - - # third engine only differs by its query execution options - # => no new engine should be created - third = SQLQueryDataSet( - sql="a different query", - credentials={"con": CONNECTION}, - execution_options=EXECUTION_OPTIONS, - ) - assert mock_engine.call_count == 1 - assert third.engines == first.engines - assert len(first.engines) == 1 - - # fourth engine has a different connection string - # => a new engine has to be created - fourth = SQLQueryDataSet( - sql=SQL_QUERY, credentials={"con": "an other connection string"} - ) - assert mock_engine.call_count == 2 - assert fourth.engines == first.engines - assert len(first.engines) == 2 diff --git a/tests/extras/datasets/pandas/test_xml_dataset.py b/tests/extras/datasets/pandas/test_xml_dataset.py deleted file mode 100644 index 9dc8f47dc1..0000000000 --- a/tests/extras/datasets/pandas/test_xml_dataset.py +++ /dev/null @@ -1,241 +0,0 @@ -from pathlib import Path, PurePosixPath - -import pandas as pd -import pytest -from adlfs import AzureBlobFileSystem -from fsspec.implementations.http import HTTPFileSystem -from fsspec.implementations.local import LocalFileSystem -from gcsfs import GCSFileSystem -from pandas.testing import assert_frame_equal -from s3fs.core import S3FileSystem - -from kedro.extras.datasets.pandas import XMLDataSet -from kedro.io import DatasetError -from kedro.io.core import PROTOCOL_DELIMITER, Version - - -@pytest.fixture -def filepath_xml(tmp_path): - return (tmp_path / "test.xml").as_posix() - - -@pytest.fixture -def xml_data_set(filepath_xml, load_args, save_args, fs_args): - return XMLDataSet( - filepath=filepath_xml, - load_args=load_args, - save_args=save_args, - fs_args=fs_args, - ) - - -@pytest.fixture -def versioned_xml_data_set(filepath_xml, load_version, save_version): - return XMLDataSet( - filepath=filepath_xml, version=Version(load_version, save_version) - ) - - -@pytest.fixture -def dummy_dataframe(): - return pd.DataFrame({"col1": [1, 2], "col2": [4, 5], "col3": [5, 6]}) - - -class TestXMLDataSet: - def test_save_and_load(self, xml_data_set, dummy_dataframe): - """Test saving and reloading the data set.""" - xml_data_set.save(dummy_dataframe) - reloaded = xml_data_set.load() - assert_frame_equal(dummy_dataframe, reloaded) - - def test_exists(self, xml_data_set, dummy_dataframe): - """Test `exists` method invocation for both existing and - nonexistent data set.""" - assert not xml_data_set.exists() - xml_data_set.save(dummy_dataframe) - assert xml_data_set.exists() - - @pytest.mark.parametrize( - "load_args", [{"k1": "v1", "index": "value"}], indirect=True - ) - def test_load_extra_params(self, xml_data_set, load_args): - """Test overriding the default load arguments.""" - for key, value in load_args.items(): - assert xml_data_set._load_args[key] == value - - @pytest.mark.parametrize( - "save_args", [{"k1": "v1", "index": "value"}], indirect=True - ) - def test_save_extra_params(self, xml_data_set, save_args): - """Test overriding the default save arguments.""" - for key, value in save_args.items(): - assert 
xml_data_set._save_args[key] == value - - @pytest.mark.parametrize( - "load_args,save_args", - [ - ({"storage_options": {"a": "b"}}, {}), - ({}, {"storage_options": {"a": "b"}}), - ({"storage_options": {"a": "b"}}, {"storage_options": {"x": "y"}}), - ], - ) - def test_storage_options_dropped(self, load_args, save_args, caplog, tmp_path): - filepath = str(tmp_path / "test.csv") - - ds = XMLDataSet(filepath=filepath, load_args=load_args, save_args=save_args) - - records = [r for r in caplog.records if r.levelname == "WARNING"] - expected_log_message = ( - f"Dropping 'storage_options' for {filepath}, " - f"please specify them under 'fs_args' or 'credentials'." - ) - assert records[0].getMessage() == expected_log_message - assert "storage_options" not in ds._save_args - assert "storage_options" not in ds._load_args - - def test_load_missing_file(self, xml_data_set): - """Check the error when trying to load missing file.""" - pattern = r"Failed while loading data from data set XMLDataSet\(.*\)" - with pytest.raises(DatasetError, match=pattern): - xml_data_set.load() - - @pytest.mark.parametrize( - "filepath,instance_type,credentials,load_path", - [ - ("s3://bucket/file.xml", S3FileSystem, {}, "s3://bucket/file.xml"), - ("file:///tmp/test.xml", LocalFileSystem, {}, "/tmp/test.xml"), - ("/tmp/test.xml", LocalFileSystem, {}, "/tmp/test.xml"), - ("gcs://bucket/file.xml", GCSFileSystem, {}, "gcs://bucket/file.xml"), - ( - "https://example.com/file.xml", - HTTPFileSystem, - {}, - "https://example.com/file.xml", - ), - ( - "abfs://bucket/file.csv", - AzureBlobFileSystem, - {"account_name": "test", "account_key": "test"}, - "abfs://bucket/file.csv", - ), - ], - ) - def test_protocol_usage( - self, filepath, instance_type, credentials, load_path, mocker - ): - data_set = XMLDataSet(filepath=filepath, credentials=credentials) - assert isinstance(data_set._fs, instance_type) - - path = filepath.split(PROTOCOL_DELIMITER, 1)[-1] - - assert str(data_set._filepath) == path - assert isinstance(data_set._filepath, PurePosixPath) - - mock_pandas_call = mocker.patch("pandas.read_xml") - data_set.load() - assert mock_pandas_call.call_count == 1 - assert mock_pandas_call.call_args_list[0][0][0] == load_path - - def test_catalog_release(self, mocker): - fs_mock = mocker.patch("fsspec.filesystem").return_value - filepath = "test.xml" - data_set = XMLDataSet(filepath=filepath) - data_set.release() - fs_mock.invalidate_cache.assert_called_once_with(filepath) - - -class TestXMLDataSetVersioned: - def test_version_str_repr(self, load_version, save_version): - """Test that version is in string representation of the class instance - when applicable.""" - filepath = "test.xml" - ds = XMLDataSet(filepath=filepath) - ds_versioned = XMLDataSet( - filepath=filepath, version=Version(load_version, save_version) - ) - assert filepath in str(ds) - assert "version" not in str(ds) - - assert filepath in str(ds_versioned) - ver_str = f"version=Version(load={load_version}, save='{save_version}')" - assert ver_str in str(ds_versioned) - assert "XMLDataSet" in str(ds_versioned) - assert "XMLDataSet" in str(ds) - assert "protocol" in str(ds_versioned) - assert "protocol" in str(ds) - - def test_save_and_load(self, versioned_xml_data_set, dummy_dataframe): - """Test that saved and reloaded data matches the original one for - the versioned data set.""" - versioned_xml_data_set.save(dummy_dataframe) - reloaded_df = versioned_xml_data_set.load() - assert_frame_equal(dummy_dataframe, reloaded_df) - - def test_no_versions(self, 
versioned_xml_data_set): - """Check the error if no versions are available for load.""" - pattern = r"Did not find any versions for XMLDataSet\(.+\)" - with pytest.raises(DatasetError, match=pattern): - versioned_xml_data_set.load() - - def test_exists(self, versioned_xml_data_set, dummy_dataframe): - """Test `exists` method invocation for versioned data set.""" - assert not versioned_xml_data_set.exists() - versioned_xml_data_set.save(dummy_dataframe) - assert versioned_xml_data_set.exists() - - def test_prevent_overwrite(self, versioned_xml_data_set, dummy_dataframe): - """Check the error when attempting to override the data set if the - corresponding hdf file for a given save version already exists.""" - versioned_xml_data_set.save(dummy_dataframe) - pattern = ( - r"Save path \'.+\' for XMLDataSet\(.+\) must " - r"not exist if versioning is enabled\." - ) - with pytest.raises(DatasetError, match=pattern): - versioned_xml_data_set.save(dummy_dataframe) - - @pytest.mark.parametrize( - "load_version", ["2019-01-01T23.59.59.999Z"], indirect=True - ) - @pytest.mark.parametrize( - "save_version", ["2019-01-02T00.00.00.000Z"], indirect=True - ) - def test_save_version_warning( - self, versioned_xml_data_set, load_version, save_version, dummy_dataframe - ): - """Check the warning when saving to the path that differs from - the subsequent load path.""" - pattern = ( - rf"Save version '{save_version}' did not match " - rf"load version '{load_version}' for XMLDataSet\(.+\)" - ) - with pytest.warns(UserWarning, match=pattern): - versioned_xml_data_set.save(dummy_dataframe) - - def test_http_filesystem_no_versioning(self): - pattern = "Versioning is not supported for HTTP protocols." - - with pytest.raises(DatasetError, match=pattern): - XMLDataSet( - filepath="https://example.com/file.xml", version=Version(None, None) - ) - - def test_versioning_existing_dataset( - self, xml_data_set, versioned_xml_data_set, dummy_dataframe - ): - """Check the error when attempting to save a versioned dataset on top of an - already existing (non-versioned) dataset.""" - xml_data_set.save(dummy_dataframe) - assert xml_data_set.exists() - assert xml_data_set._filepath == versioned_xml_data_set._filepath - pattern = ( - f"(?=.*file with the same name already exists in the directory)" - f"(?=.*{versioned_xml_data_set._filepath.parent.as_posix()})" - ) - with pytest.raises(DatasetError, match=pattern): - versioned_xml_data_set.save(dummy_dataframe) - - # Remove non-versioned dataset and try again - Path(xml_data_set._filepath.as_posix()).unlink() - versioned_xml_data_set.save(dummy_dataframe) - assert versioned_xml_data_set.exists() diff --git a/tests/extras/datasets/pickle/__init__.py b/tests/extras/datasets/pickle/__init__.py deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/tests/extras/datasets/pickle/test_pickle_dataset.py b/tests/extras/datasets/pickle/test_pickle_dataset.py deleted file mode 100644 index 65f7495a06..0000000000 --- a/tests/extras/datasets/pickle/test_pickle_dataset.py +++ /dev/null @@ -1,269 +0,0 @@ -import pickle -from pathlib import Path, PurePosixPath - -import pandas as pd -import pytest -from fsspec.implementations.http import HTTPFileSystem -from fsspec.implementations.local import LocalFileSystem -from gcsfs import GCSFileSystem -from pandas.testing import assert_frame_equal -from s3fs.core import S3FileSystem - -from kedro.extras.datasets.pickle import PickleDataSet -from kedro.io import DatasetError -from kedro.io.core import PROTOCOL_DELIMITER, Version - - 
-@pytest.fixture -def filepath_pickle(tmp_path): - return (tmp_path / "test.pkl").as_posix() - - -@pytest.fixture(params=["pickle"]) -def backend(request): - return request.param - - -@pytest.fixture -def pickle_data_set(filepath_pickle, backend, load_args, save_args, fs_args): - return PickleDataSet( - filepath=filepath_pickle, - backend=backend, - load_args=load_args, - save_args=save_args, - fs_args=fs_args, - ) - - -@pytest.fixture -def versioned_pickle_data_set(filepath_pickle, load_version, save_version): - return PickleDataSet( - filepath=filepath_pickle, version=Version(load_version, save_version) - ) - - -@pytest.fixture -def dummy_dataframe(): - return pd.DataFrame({"col1": [1, 2], "col2": [4, 5], "col3": [5, 6]}) - - -class TestPickleDataSet: - @pytest.mark.parametrize( - "backend,load_args,save_args", - [ - ("pickle", None, None), - ("joblib", None, None), - ("dill", None, None), - ("compress_pickle", {"compression": "lz4"}, {"compression": "lz4"}), - ], - indirect=True, - ) - def test_save_and_load(self, pickle_data_set, dummy_dataframe): - """Test saving and reloading the data set.""" - pickle_data_set.save(dummy_dataframe) - reloaded = pickle_data_set.load() - assert_frame_equal(dummy_dataframe, reloaded) - assert pickle_data_set._fs_open_args_load == {} - assert pickle_data_set._fs_open_args_save == {"mode": "wb"} - - def test_exists(self, pickle_data_set, dummy_dataframe): - """Test `exists` method invocation for both existing and - nonexistent data set.""" - assert not pickle_data_set.exists() - pickle_data_set.save(dummy_dataframe) - assert pickle_data_set.exists() - - @pytest.mark.parametrize( - "load_args", [{"k1": "v1", "errors": "strict"}], indirect=True - ) - def test_load_extra_params(self, pickle_data_set, load_args): - """Test overriding the default load arguments.""" - for key, value in load_args.items(): - assert pickle_data_set._load_args[key] == value - - @pytest.mark.parametrize("save_args", [{"k1": "v1", "protocol": 2}], indirect=True) - def test_save_extra_params(self, pickle_data_set, save_args): - """Test overriding the default save arguments.""" - for key, value in save_args.items(): - assert pickle_data_set._save_args[key] == value - - @pytest.mark.parametrize( - "fs_args", - [{"open_args_load": {"mode": "rb", "compression": "gzip"}}], - indirect=True, - ) - def test_open_extra_args(self, pickle_data_set, fs_args): - assert pickle_data_set._fs_open_args_load == fs_args["open_args_load"] - assert pickle_data_set._fs_open_args_save == {"mode": "wb"} # default unchanged - - def test_load_missing_file(self, pickle_data_set): - """Check the error when trying to load missing file.""" - pattern = r"Failed while loading data from data set PickleDataSet\(.*\)" - with pytest.raises(DatasetError, match=pattern): - pickle_data_set.load() - - @pytest.mark.parametrize( - "filepath,instance_type", - [ - ("s3://bucket/file.pkl", S3FileSystem), - ("file:///tmp/test.pkl", LocalFileSystem), - ("/tmp/test.pkl", LocalFileSystem), - ("gcs://bucket/file.pkl", GCSFileSystem), - ("https://example.com/file.pkl", HTTPFileSystem), - ], - ) - def test_protocol_usage(self, filepath, instance_type): - data_set = PickleDataSet(filepath=filepath) - assert isinstance(data_set._fs, instance_type) - - path = filepath.split(PROTOCOL_DELIMITER, 1)[-1] - - assert str(data_set._filepath) == path - assert isinstance(data_set._filepath, PurePosixPath) - - def test_catalog_release(self, mocker): - fs_mock = mocker.patch("fsspec.filesystem").return_value - filepath = "test.pkl" - data_set = 
PickleDataSet(filepath=filepath) - data_set.release() - fs_mock.invalidate_cache.assert_called_once_with(filepath) - - def test_unserialisable_data(self, pickle_data_set, dummy_dataframe, mocker): - mocker.patch("pickle.dump", side_effect=pickle.PickleError) - pattern = r".+ was not serialised due to:.*" - - with pytest.raises(DatasetError, match=pattern): - pickle_data_set.save(dummy_dataframe) - - def test_invalid_backend(self, mocker): - pattern = ( - r"Selected backend 'invalid' should satisfy the pickle interface. " - r"Missing one of 'load' and 'dump' on the backend." - ) - mocker.patch( - "kedro.extras.datasets.pickle.pickle_dataset.importlib.import_module", - return_value=object, - ) - with pytest.raises(ValueError, match=pattern): - PickleDataSet(filepath="test.pkl", backend="invalid") - - def test_no_backend(self, mocker): - pattern = ( - r"Selected backend 'fake.backend.does.not.exist' could not be imported. " - r"Make sure it is installed and importable." - ) - mocker.patch( - "kedro.extras.datasets.pickle.pickle_dataset.importlib.import_module", - side_effect=ImportError, - ) - with pytest.raises(ImportError, match=pattern): - PickleDataSet(filepath="test.pkl", backend="fake.backend.does.not.exist") - - def test_copy(self, pickle_data_set): - pickle_data_set_copy = pickle_data_set._copy() - assert pickle_data_set_copy is not pickle_data_set - assert pickle_data_set_copy._describe() == pickle_data_set._describe() - - -class TestPickleDataSetVersioned: - def test_version_str_repr(self, load_version, save_version): - """Test that version is in string representation of the class instance - when applicable.""" - filepath = "test.pkl" - ds = PickleDataSet(filepath=filepath) - ds_versioned = PickleDataSet( - filepath=filepath, version=Version(load_version, save_version) - ) - assert filepath in str(ds) - assert "version" not in str(ds) - - assert filepath in str(ds_versioned) - ver_str = f"version=Version(load={load_version}, save='{save_version}')" - assert ver_str in str(ds_versioned) - assert "PickleDataSet" in str(ds_versioned) - assert "PickleDataSet" in str(ds) - assert "protocol" in str(ds_versioned) - assert "protocol" in str(ds) - assert "backend" in str(ds_versioned) - assert "backend" in str(ds) - - def test_save_and_load(self, versioned_pickle_data_set, dummy_dataframe): - """Test that saved and reloaded data matches the original one for - the versioned data set.""" - versioned_pickle_data_set.save(dummy_dataframe) - reloaded_df = versioned_pickle_data_set.load() - assert_frame_equal(dummy_dataframe, reloaded_df) - - def test_no_versions(self, versioned_pickle_data_set): - """Check the error if no versions are available for load.""" - pattern = r"Did not find any versions for PickleDataSet\(.+\)" - with pytest.raises(DatasetError, match=pattern): - versioned_pickle_data_set.load() - - def test_exists(self, versioned_pickle_data_set, dummy_dataframe): - """Test `exists` method invocation for versioned data set.""" - assert not versioned_pickle_data_set.exists() - versioned_pickle_data_set.save(dummy_dataframe) - assert versioned_pickle_data_set.exists() - - def test_prevent_overwrite(self, versioned_pickle_data_set, dummy_dataframe): - """Check the error when attempting to override the data set if the - corresponding Pickle file for a given save version already exists.""" - versioned_pickle_data_set.save(dummy_dataframe) - pattern = ( - r"Save path \'.+\' for PickleDataSet\(.+\) must " - r"not exist if versioning is enabled\." 
- ) - with pytest.raises(DatasetError, match=pattern): - versioned_pickle_data_set.save(dummy_dataframe) - - @pytest.mark.parametrize( - "load_version", ["2019-01-01T23.59.59.999Z"], indirect=True - ) - @pytest.mark.parametrize( - "save_version", ["2019-01-02T00.00.00.000Z"], indirect=True - ) - def test_save_version_warning( - self, versioned_pickle_data_set, load_version, save_version, dummy_dataframe - ): - """Check the warning when saving to the path that differs from - the subsequent load path.""" - pattern = ( - rf"Save version '{save_version}' did not match load version " - rf"'{load_version}' for PickleDataSet\(.+\)" - ) - with pytest.warns(UserWarning, match=pattern): - versioned_pickle_data_set.save(dummy_dataframe) - - def test_http_filesystem_no_versioning(self): - pattern = "Versioning is not supported for HTTP protocols." - - with pytest.raises(DatasetError, match=pattern): - PickleDataSet( - filepath="https://example.com/file.pkl", version=Version(None, None) - ) - - def test_versioning_existing_dataset( - self, pickle_data_set, versioned_pickle_data_set, dummy_dataframe - ): - """Check the error when attempting to save a versioned dataset on top of an - already existing (non-versioned) dataset.""" - pickle_data_set.save(dummy_dataframe) - assert pickle_data_set.exists() - assert pickle_data_set._filepath == versioned_pickle_data_set._filepath - pattern = ( - f"(?=.*file with the same name already exists in the directory)" - f"(?=.*{versioned_pickle_data_set._filepath.parent.as_posix()})" - ) - with pytest.raises(DatasetError, match=pattern): - versioned_pickle_data_set.save(dummy_dataframe) - - # Remove non-versioned dataset and try again - Path(pickle_data_set._filepath.as_posix()).unlink() - versioned_pickle_data_set.save(dummy_dataframe) - assert versioned_pickle_data_set.exists() - - def test_copy(self, versioned_pickle_data_set): - pickle_data_set_copy = versioned_pickle_data_set._copy() - assert pickle_data_set_copy is not versioned_pickle_data_set - assert pickle_data_set_copy._describe() == versioned_pickle_data_set._describe() diff --git a/tests/extras/datasets/pillow/__init__.py b/tests/extras/datasets/pillow/__init__.py deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/tests/extras/datasets/pillow/data/image.png b/tests/extras/datasets/pillow/data/image.png deleted file mode 100644 index 4147f3ef7a..0000000000 Binary files a/tests/extras/datasets/pillow/data/image.png and /dev/null differ diff --git a/tests/extras/datasets/pillow/test_image_dataset.py b/tests/extras/datasets/pillow/test_image_dataset.py deleted file mode 100644 index d3cb450989..0000000000 --- a/tests/extras/datasets/pillow/test_image_dataset.py +++ /dev/null @@ -1,231 +0,0 @@ -from pathlib import Path, PurePosixPath -from time import sleep - -import pytest -from fsspec.implementations.http import HTTPFileSystem -from fsspec.implementations.local import LocalFileSystem -from PIL import Image, ImageChops -from s3fs.core import S3FileSystem - -from kedro.extras.datasets.pillow import ImageDataSet -from kedro.io import DatasetError -from kedro.io.core import PROTOCOL_DELIMITER, Version, generate_timestamp - - -@pytest.fixture -def filepath_png(tmp_path): - return (tmp_path / "test.png").as_posix() - - -@pytest.fixture -def image_dataset(filepath_png, save_args, fs_args): - return ImageDataSet(filepath=filepath_png, save_args=save_args, fs_args=fs_args) - - -@pytest.fixture -def versioned_image_dataset(filepath_png, load_version, save_version): - return ImageDataSet( - 
filepath=filepath_png, version=Version(load_version, save_version) - ) - - -@pytest.fixture(scope="module") -def image_object(): - filepath = str(Path(__file__).parent / "data/image.png") - return Image.open(filepath).copy() - - -def images_equal(image_1, image_2): - diff = ImageChops.difference(image_1, image_2) - return not diff.getbbox() - - -class TestImageDataSet: - def test_save_and_load(self, image_dataset, image_object): - """Test saving and reloading the data set.""" - image_dataset.save(image_object) - reloaded_image = image_dataset.load() - assert images_equal(image_object, reloaded_image) - assert image_dataset._fs_open_args_save == {"mode": "wb"} - - def test_exists(self, image_dataset, image_object): - """Test `exists` method invocation for both existing and - nonexistent data set.""" - assert not image_dataset.exists() - image_dataset.save(image_object) - assert image_dataset.exists() - - @pytest.mark.parametrize( - "save_args", [{"format": "png", "index": "value"}], indirect=True - ) - def test_load_extra_params(self, image_dataset, save_args): - """Test overriding the default load arguments.""" - for key, value in save_args.items(): - assert image_dataset._save_args[key] == value - - @pytest.mark.parametrize( - "fs_args", - [ - { - "open_args_load": {"mode": "r", "compression": "gzip"}, - "open_args_save": {"fs_save": "fs_save"}, - } - ], - indirect=True, - ) - def test_open_extra_args(self, image_dataset, fs_args): - assert image_dataset._fs_open_args_load == fs_args["open_args_load"] - expected_save_fs_args = {"mode": "wb"} # default - expected_save_fs_args.update(fs_args["open_args_save"]) - assert image_dataset._fs_open_args_save == expected_save_fs_args - - def test_load_missing_file(self, image_dataset): - """Check the error when trying to load missing file.""" - pattern = r"Failed while loading data from data set ImageDataSet\(.*\)" - with pytest.raises(DatasetError, match=pattern): - image_dataset.load() - - @pytest.mark.parametrize( - "filepath,instance_type", - [ - ("s3://bucket/file.png", S3FileSystem), - ("file:///tmp/test.png", LocalFileSystem), - ("/tmp/test.png", LocalFileSystem), - ("https://example.com/file.png", HTTPFileSystem), - ], - ) - def test_protocol_usage(self, filepath, instance_type): - data_set = ImageDataSet(filepath=filepath) - assert isinstance(data_set._fs, instance_type) - - path = filepath.split(PROTOCOL_DELIMITER, 1)[-1] - - assert str(data_set._filepath) == path - assert isinstance(data_set._filepath, PurePosixPath) - - def test_catalog_release(self, mocker): - fs_mock = mocker.patch("fsspec.filesystem").return_value - filepath = "test.png" - data_set = ImageDataSet(filepath=filepath) - data_set.release() - fs_mock.invalidate_cache.assert_called_once_with(filepath) - - -class TestImageDataSetVersioned: - def test_version_str_repr(self, load_version, save_version): - """Test that version is in string representation of the class instance - when applicable.""" - filepath = "/tmp/test.png" - ds = ImageDataSet(filepath=filepath) - ds_versioned = ImageDataSet( - filepath=filepath, version=Version(load_version, save_version) - ) - assert filepath in str(ds) - assert filepath in str(ds_versioned) - - assert "version" not in str(ds) - ver_str = f"version=Version(load={load_version}, save='{save_version}')" - assert ver_str in str(ds_versioned) - assert "ImageDataSet" in str(ds_versioned) - assert "ImageDataSet" in str(ds) - assert "protocol" in str(ds_versioned) - assert "protocol" in str(ds) - - def test_save_and_load(self, 
versioned_image_dataset, image_object): - """Test that saved and reloaded data matches the original one for - the versioned data set.""" - versioned_image_dataset.save(image_object) - reloaded_image = versioned_image_dataset.load() - assert images_equal(image_object, reloaded_image) - - def test_multiple_loads(self, versioned_image_dataset, image_object, filepath_png): - """Test that if a new version is created mid-run, by an - external system, it won't be loaded in the current run.""" - versioned_image_dataset.save(image_object) - v1 = versioned_image_dataset.resolve_load_version() - - # Sometimes for some reason `v1 == v_new` on Windows. - # `sleep()` was added to fix this. - sleep(0.5) - # force-drop a newer version into the same location - v_new = generate_timestamp() - ImageDataSet(filepath=filepath_png, version=Version(v_new, v_new)).save( - image_object - ) - - v2 = versioned_image_dataset.resolve_load_version() - - assert v2 == v1 # v2 should not be v_new! - ds_new = ImageDataSet(filepath=filepath_png, version=Version(None, None)) - assert ( - ds_new.resolve_load_version() == v_new - ) # new version is discoverable by a new instance - - def test_no_versions(self, versioned_image_dataset): - """Check the error if no versions are available for load.""" - pattern = r"Did not find any versions for ImageDataSet\(.+\)" - with pytest.raises(DatasetError, match=pattern): - versioned_image_dataset.load() - - def test_exists(self, versioned_image_dataset, image_object): - """Test `exists` method invocation for versioned data set.""" - assert not versioned_image_dataset.exists() - versioned_image_dataset.save(image_object) - assert versioned_image_dataset.exists() - - def test_prevent_overwrite(self, versioned_image_dataset, image_object): - """Check the error when attempting to override the data set if the - corresponding image file for a given save version already exists.""" - versioned_image_dataset.save(image_object) - pattern = ( - r"Save path \'.+\' for ImageDataSet\(.+\) must " - r"not exist if versioning is enabled\." - ) - with pytest.raises(DatasetError, match=pattern): - versioned_image_dataset.save(image_object) - - @pytest.mark.parametrize( - "load_version", ["2019-01-01T23.59.59.999Z"], indirect=True - ) - @pytest.mark.parametrize( - "save_version", ["2019-01-02T00.00.00.000Z"], indirect=True - ) - def test_save_version_warning( - self, versioned_image_dataset, load_version, save_version, image_object - ): - """Check the warning when saving to the path that differs from - the subsequent load path.""" - pattern = ( - rf"Save version '{save_version}' did not match load version " - rf"'{load_version}' for ImageDataSet\(.+\)" - ) - with pytest.warns(UserWarning, match=pattern): - versioned_image_dataset.save(image_object) - - def test_http_filesystem_no_versioning(self): - pattern = "Versioning is not supported for HTTP protocols." 
- - with pytest.raises(DatasetError, match=pattern): - ImageDataSet( - filepath="https://example.com/file.png", version=Version(None, None) - ) - - def test_versioning_existing_dataset( - self, image_dataset, versioned_image_dataset, image_object - ): - """Check the error when attempting to save a versioned dataset on top of an - already existing (non-versioned) dataset.""" - image_dataset.save(image_object) - assert image_dataset.exists() - assert image_dataset._filepath == versioned_image_dataset._filepath - pattern = ( - f"(?=.*file with the same name already exists in the directory)" - f"(?=.*{versioned_image_dataset._filepath.parent.as_posix()})" - ) - with pytest.raises(DatasetError, match=pattern): - versioned_image_dataset.save(image_object) - - # Remove non-versioned dataset and try again - Path(image_dataset._filepath.as_posix()).unlink() - versioned_image_dataset.save(image_object) - assert versioned_image_dataset.exists() diff --git a/tests/extras/datasets/plotly/__init__.py b/tests/extras/datasets/plotly/__init__.py deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/tests/extras/datasets/plotly/test_json_dataset.py b/tests/extras/datasets/plotly/test_json_dataset.py deleted file mode 100644 index 552d34bb27..0000000000 --- a/tests/extras/datasets/plotly/test_json_dataset.py +++ /dev/null @@ -1,101 +0,0 @@ -from pathlib import PurePosixPath - -import plotly.express as px -import pytest -from adlfs import AzureBlobFileSystem -from fsspec.implementations.http import HTTPFileSystem -from fsspec.implementations.local import LocalFileSystem -from gcsfs import GCSFileSystem -from s3fs.core import S3FileSystem - -from kedro.extras.datasets.plotly import JSONDataSet -from kedro.io import DatasetError -from kedro.io.core import PROTOCOL_DELIMITER - - -@pytest.fixture -def filepath_json(tmp_path): - return (tmp_path / "test.json").as_posix() - - -@pytest.fixture -def json_data_set(filepath_json, load_args, save_args, fs_args): - return JSONDataSet( - filepath=filepath_json, - load_args=load_args, - save_args=save_args, - fs_args=fs_args, - ) - - -@pytest.fixture -def dummy_plot(): - return px.scatter(x=[1, 2, 3], y=[1, 3, 2], title="Test") - - -class TestJSONDataSet: - def test_save_and_load(self, json_data_set, dummy_plot): - """Test saving and reloading the data set.""" - json_data_set.save(dummy_plot) - reloaded = json_data_set.load() - assert dummy_plot == reloaded - assert json_data_set._fs_open_args_load == {} - assert json_data_set._fs_open_args_save == {"mode": "w"} - - def test_exists(self, json_data_set, dummy_plot): - """Test `exists` method invocation for both existing and - nonexistent data set.""" - assert not json_data_set.exists() - json_data_set.save(dummy_plot) - assert json_data_set.exists() - - def test_load_missing_file(self, json_data_set): - """Check the error when trying to load missing file.""" - pattern = r"Failed while loading data from data set JSONDataSet\(.*\)" - with pytest.raises(DatasetError, match=pattern): - json_data_set.load() - - @pytest.mark.parametrize("save_args", [{"pretty": True}]) - def test_save_extra_params(self, json_data_set, save_args): - """Test overriding default save args""" - for k, v in save_args.items(): - assert json_data_set._save_args[k] == v - - @pytest.mark.parametrize( - "load_args", [{"output_type": "FigureWidget", "skip_invalid": True}] - ) - def test_load_extra_params(self, json_data_set, load_args): - """Test overriding default save args""" - for k, v in load_args.items(): - assert 
json_data_set._load_args[k] == v - - @pytest.mark.parametrize( - "filepath,instance_type,credentials", - [ - ("s3://bucket/file.json", S3FileSystem, {}), - ("file:///tmp/test.json", LocalFileSystem, {}), - ("/tmp/test.json", LocalFileSystem, {}), - ("gcs://bucket/file.json", GCSFileSystem, {}), - ("https://example.com/file.json", HTTPFileSystem, {}), - ( - "abfs://bucket/file.csv", - AzureBlobFileSystem, - {"account_name": "test", "account_key": "test"}, - ), - ], - ) - def test_protocol_usage(self, filepath, instance_type, credentials): - data_set = JSONDataSet(filepath=filepath, credentials=credentials) - assert isinstance(data_set._fs, instance_type) - - path = filepath.split(PROTOCOL_DELIMITER, 1)[-1] - - assert str(data_set._filepath) == path - assert isinstance(data_set._filepath, PurePosixPath) - - def test_catalog_release(self, mocker): - fs_mock = mocker.patch("fsspec.filesystem").return_value - filepath = "test.json" - data_set = JSONDataSet(filepath=filepath) - data_set.release() - fs_mock.invalidate_cache.assert_called_once_with(filepath) diff --git a/tests/extras/datasets/plotly/test_plotly_dataset.py b/tests/extras/datasets/plotly/test_plotly_dataset.py deleted file mode 100644 index 042a414905..0000000000 --- a/tests/extras/datasets/plotly/test_plotly_dataset.py +++ /dev/null @@ -1,108 +0,0 @@ -from pathlib import PurePosixPath - -import pandas as pd -import pytest -from adlfs import AzureBlobFileSystem -from fsspec.implementations.http import HTTPFileSystem -from fsspec.implementations.local import LocalFileSystem -from gcsfs import GCSFileSystem -from plotly import graph_objects -from plotly.graph_objs import Scatter -from s3fs.core import S3FileSystem - -from kedro.extras.datasets.plotly import PlotlyDataSet -from kedro.io import DatasetError -from kedro.io.core import PROTOCOL_DELIMITER - - -@pytest.fixture -def filepath_json(tmp_path): - return (tmp_path / "test.json").as_posix() - - -@pytest.fixture -def plotly_data_set(filepath_json, load_args, save_args, fs_args, plotly_args): - return PlotlyDataSet( - filepath=filepath_json, - load_args=load_args, - save_args=save_args, - fs_args=fs_args, - plotly_args=plotly_args, - ) - - -@pytest.fixture -def plotly_args(): - return { - "fig": {"orientation": "h", "x": "col1", "y": "col2"}, - "layout": {"title": "Test", "xaxis_title": "x", "yaxis_title": "y"}, - "type": "scatter", - } - - -@pytest.fixture -def dummy_dataframe(): - return pd.DataFrame({"col1": [1, 2], "col2": [4, 5], "col3": [5, 6]}) - - -class TestPlotlyDataSet: - def test_save_and_load(self, plotly_data_set, dummy_dataframe): - """Test saving and reloading the data set.""" - plotly_data_set.save(dummy_dataframe) - reloaded = plotly_data_set.load() - assert isinstance(reloaded, graph_objects.Figure) - assert "Test" in str(reloaded["layout"]["title"]) - assert isinstance(reloaded["data"][0], Scatter) - - def test_exists(self, plotly_data_set, dummy_dataframe): - """Test `exists` method invocation for both existing and - nonexistent data set.""" - assert not plotly_data_set.exists() - plotly_data_set.save(dummy_dataframe) - assert plotly_data_set.exists() - - def test_load_missing_file(self, plotly_data_set): - """Check the error when trying to load missing file.""" - pattern = r"Failed while loading data from data set PlotlyDataSet\(.*\)" - with pytest.raises(DatasetError, match=pattern): - plotly_data_set.load() - - @pytest.mark.parametrize( - "filepath,instance_type,credentials", - [ - ("s3://bucket/file.json", S3FileSystem, {}), - ("file:///tmp/test.json", 
LocalFileSystem, {}), - ("/tmp/test.json", LocalFileSystem, {}), - ("gcs://bucket/file.json", GCSFileSystem, {}), - ("https://example.com/file.json", HTTPFileSystem, {}), - ( - "abfs://bucket/file.csv", - AzureBlobFileSystem, - {"account_name": "test", "account_key": "test"}, - ), - ], - ) - def test_protocol_usage(self, filepath, instance_type, credentials, plotly_args): - data_set = PlotlyDataSet( - filepath=filepath, credentials=credentials, plotly_args=plotly_args - ) - assert isinstance(data_set._fs, instance_type) - - path = filepath.split(PROTOCOL_DELIMITER, 1)[-1] - - assert str(data_set._filepath) == path - assert isinstance(data_set._filepath, PurePosixPath) - - def test_catalog_release(self, mocker, plotly_args): - fs_mock = mocker.patch("fsspec.filesystem").return_value - filepath = "test.json" - data_set = PlotlyDataSet(filepath=filepath, plotly_args=plotly_args) - data_set.release() - fs_mock.invalidate_cache.assert_called_once_with(filepath) - - def test_fail_if_invalid_plotly_args_provided(self): - plotly_args = [] - filepath = "test.json" - data_set = PlotlyDataSet(filepath=filepath, plotly_args=plotly_args) - with pytest.raises(DatasetError): - data_set.save(dummy_dataframe) diff --git a/tests/extras/datasets/redis/__init__.py b/tests/extras/datasets/redis/__init__.py deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/tests/extras/datasets/redis/test_redis_dataset.py b/tests/extras/datasets/redis/test_redis_dataset.py deleted file mode 100644 index bd4c4da9fa..0000000000 --- a/tests/extras/datasets/redis/test_redis_dataset.py +++ /dev/null @@ -1,165 +0,0 @@ -"""Tests ``PickleDataSet``.""" - -import importlib -import pickle - -import numpy as np -import pandas as pd -import pytest -import redis -from pandas.testing import assert_frame_equal - -from kedro.extras.datasets.redis import PickleDataSet -from kedro.io import DatasetError - - -@pytest.fixture(params=["pickle"]) -def backend(request): - return request.param - - -@pytest.fixture(params=["key"]) -def key(request): - return request.param - - -@pytest.fixture -def redis_args(): - return { - "from_url_args": {"arg1": "1", "arg2": "2", "url": "redis://127.0.0.1:6379"} - } - - -@pytest.fixture -def dummy_object(): - """Test data for saving.""" - return pd.DataFrame(np.random.random((3, 3)), columns=["a", "b", "c"]) - - -@pytest.fixture -def serialised_dummy_object(backend, dummy_object, save_args): - """Serialise test data.""" - imported_backend = importlib.import_module(backend) - save_args = save_args or {} - return imported_backend.dumps(dummy_object, **save_args) - - -@pytest.fixture -def pickle_data_set(mocker, key, backend, load_args, save_args, redis_args): - mocker.patch( - "redis.StrictRedis.from_url", return_value=redis.Redis.from_url("redis://") - ) - return PickleDataSet( - key=key, - backend=backend, - load_args=load_args, - save_args=save_args, - redis_args=redis_args, - ) - - -class TestPickleDataSet: - @pytest.mark.parametrize( - "key,backend,load_args,save_args", - [ - ("a", "pickle", None, None), - (1, "dill", None, None), - ("key", "compress_pickle", {"compression": "lz4"}, {"compression": "lz4"}), - ], - indirect=True, - ) - def test_save_and_load( - self, - pickle_data_set, - mocker, - dummy_object, - serialised_dummy_object, - key, - ): - """Test saving and reloading the data set.""" - set_mocker = mocker.patch("redis.StrictRedis.set") - get_mocker = mocker.patch( - "redis.StrictRedis.get", return_value=serialised_dummy_object - ) - pickle_data_set.save(dummy_object) - 
mocker.patch("redis.StrictRedis.exists", return_value=True) - loaded_dummy_object = pickle_data_set.load() - set_mocker.assert_called_once_with( - key, - serialised_dummy_object, - ) - get_mocker.assert_called_once_with(key) - assert_frame_equal(loaded_dummy_object, dummy_object) - - def test_exists(self, mocker, pickle_data_set, dummy_object, key): - """Test `exists` method invocation for both existing and - nonexistent data set.""" - mocker.patch("redis.StrictRedis.exists", return_value=False) - assert not pickle_data_set.exists() - mocker.patch("redis.StrictRedis.set") - pickle_data_set.save(dummy_object) - exists_mocker = mocker.patch("redis.StrictRedis.exists", return_value=True) - assert pickle_data_set.exists() - exists_mocker.assert_called_once_with(key) - - def test_exists_raises_error(self, pickle_data_set): - """Check the error when trying to assert existence with no redis server.""" - pattern = r"The existence of key " - with pytest.raises(DatasetError, match=pattern): - pickle_data_set.exists() - - @pytest.mark.parametrize( - "load_args", [{"k1": "v1", "errors": "strict"}], indirect=True - ) - def test_load_extra_params(self, pickle_data_set, load_args): - """Test overriding the default load arguments.""" - for key, value in load_args.items(): - assert pickle_data_set._load_args[key] == value - - @pytest.mark.parametrize("save_args", [{"k1": "v1", "protocol": 2}], indirect=True) - def test_save_extra_params(self, pickle_data_set, save_args): - """Test overriding the default save arguments.""" - for key, value in save_args.items(): - assert pickle_data_set._save_args[key] == value - - def test_redis_extra_args(self, pickle_data_set, redis_args): - assert pickle_data_set._redis_from_url_args == redis_args["from_url_args"] - assert pickle_data_set._redis_set_args == {} # default unchanged - - def test_load_missing_key(self, mocker, pickle_data_set): - """Check the error when trying to load missing file.""" - pattern = r"The provided key " - mocker.patch("redis.StrictRedis.exists", return_value=False) - with pytest.raises(DatasetError, match=pattern): - pickle_data_set.load() - - def test_unserialisable_data(self, pickle_data_set, dummy_object, mocker): - mocker.patch("pickle.dumps", side_effect=pickle.PickleError) - pattern = r".+ was not serialised due to:.*" - - with pytest.raises(DatasetError, match=pattern): - pickle_data_set.save(dummy_object) - - def test_invalid_backend(self, mocker): - pattern = ( - r"Selected backend 'invalid' should satisfy the pickle interface. " - r"Missing one of 'loads' and 'dumps' on the backend." - ) - mocker.patch( - "kedro.extras.datasets.pickle.pickle_dataset.importlib.import_module", - return_value=object, - ) - with pytest.raises(ValueError, match=pattern): - PickleDataSet(key="key", backend="invalid") - - def test_no_backend(self, mocker): - pattern = ( - r"Selected backend 'fake.backend.does.not.exist' could not be imported. " - r"Make sure it is installed and importable." 
- ) - mocker.patch( - "kedro.extras.datasets.pickle.pickle_dataset.importlib.import_module", - side_effect=ImportError, - ) - with pytest.raises(ImportError, match=pattern): - PickleDataSet("key", backend="fake.backend.does.not.exist") diff --git a/tests/extras/datasets/spark/__init__.py b/tests/extras/datasets/spark/__init__.py deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/tests/extras/datasets/spark/conftest.py b/tests/extras/datasets/spark/conftest.py deleted file mode 100644 index 3e3ae1c544..0000000000 --- a/tests/extras/datasets/spark/conftest.py +++ /dev/null @@ -1,41 +0,0 @@ -""" -This file contains the fixtures that are reusable by any tests within -this directory. You don't need to import the fixtures as pytest will -discover them automatically. More info here: -https://docs.pytest.org/en/latest/fixture.html -""" -import pytest -from delta import configure_spark_with_delta_pip -from filelock import FileLock - -try: - from pyspark.sql import SparkSession -except ImportError: # pragma: no cover - pass # this is only for test discovery to succeed on Python 3.8, 3.9 - - -def _setup_spark_session(): - return configure_spark_with_delta_pip( - SparkSession.builder.appName("MyApp") - .master("local[*]") - .config("spark.sql.extensions", "io.delta.sql.DeltaSparkSessionExtension") - .config( - "spark.sql.catalog.spark_catalog", - "org.apache.spark.sql.delta.catalog.DeltaCatalog", - ) - ).getOrCreate() - - -@pytest.fixture(scope="module", autouse=True) -def spark_session(tmp_path_factory): - # When running these spark tests with pytest-xdist, we need to make sure - # that the spark session setup on each test process don't interfere with each other. - # Therefore, we block the process during the spark session setup. - # Locking procedure comes from pytest-xdist's own recommendation: - # https://github.com/pytest-dev/pytest-xdist#making-session-scoped-fixtures-execute-only-once - root_tmp_dir = tmp_path_factory.getbasetemp().parent - lock = root_tmp_dir / "semaphore.lock" - with FileLock(lock): - spark = _setup_spark_session() - yield spark - spark.stop() diff --git a/tests/extras/datasets/spark/data/test.parquet b/tests/extras/datasets/spark/data/test.parquet deleted file mode 100644 index 024ef2d9f9..0000000000 Binary files a/tests/extras/datasets/spark/data/test.parquet and /dev/null differ diff --git a/tests/extras/datasets/spark/test_deltatable_dataset.py b/tests/extras/datasets/spark/test_deltatable_dataset.py deleted file mode 100644 index a0ad5bc9d9..0000000000 --- a/tests/extras/datasets/spark/test_deltatable_dataset.py +++ /dev/null @@ -1,100 +0,0 @@ -import pytest -from delta import DeltaTable -from pyspark import __version__ -from pyspark.sql import SparkSession -from pyspark.sql.types import IntegerType, StringType, StructField, StructType -from pyspark.sql.utils import AnalysisException -from semver import VersionInfo - -from kedro.extras.datasets.spark import DeltaTableDataSet, SparkDataSet -from kedro.io import DataCatalog, DatasetError -from kedro.pipeline import node -from kedro.pipeline.modular_pipeline import pipeline as modular_pipeline -from kedro.runner import ParallelRunner - -SPARK_VERSION = VersionInfo.parse(__version__) - - -@pytest.fixture -def sample_spark_df(): - schema = StructType( - [ - StructField("name", StringType(), True), - StructField("age", IntegerType(), True), - ] - ) - - data = [("Alex", 31), ("Bob", 12), ("Clarke", 65), ("Dave", 29)] - - return SparkSession.builder.getOrCreate().createDataFrame(data, schema) - - -class 
TestDeltaTableDataSet: - def test_load(self, tmp_path, sample_spark_df): - filepath = (tmp_path / "test_data").as_posix() - spark_delta_ds = SparkDataSet(filepath=filepath, file_format="delta") - spark_delta_ds.save(sample_spark_df) - loaded_with_spark = spark_delta_ds.load() - assert loaded_with_spark.exceptAll(sample_spark_df).count() == 0 - - delta_ds = DeltaTableDataSet(filepath=filepath) - delta_table = delta_ds.load() - - assert isinstance(delta_table, DeltaTable) - loaded_with_deltalake = delta_table.toDF() - assert loaded_with_deltalake.exceptAll(loaded_with_spark).count() == 0 - - def test_save(self, tmp_path, sample_spark_df): - filepath = (tmp_path / "test_data").as_posix() - delta_ds = DeltaTableDataSet(filepath=filepath) - assert not delta_ds.exists() - - pattern = "DeltaTableDataSet is a read only dataset type" - with pytest.raises(DatasetError, match=pattern): - delta_ds.save(sample_spark_df) - - # check that indeed nothing is written - assert not delta_ds.exists() - - def test_exists(self, tmp_path, sample_spark_df): - filepath = (tmp_path / "test_data").as_posix() - delta_ds = DeltaTableDataSet(filepath=filepath) - - assert not delta_ds.exists() - - spark_delta_ds = SparkDataSet(filepath=filepath, file_format="delta") - spark_delta_ds.save(sample_spark_df) - - assert delta_ds.exists() - - def test_exists_raises_error(self, mocker): - delta_ds = DeltaTableDataSet(filepath="") - if SPARK_VERSION.match(">=3.4.0"): - mocker.patch.object( - delta_ds, "_get_spark", side_effect=AnalysisException("Other Exception") - ) - else: - mocker.patch.object( - delta_ds, - "_get_spark", - side_effect=AnalysisException("Other Exception", []), - ) - with pytest.raises(DatasetError, match="Other Exception"): - delta_ds.exists() - - @pytest.mark.parametrize("is_async", [False, True]) - def test_parallel_runner(self, is_async): - """Test ParallelRunner with SparkDataSet fails.""" - - def no_output(x): - _ = x + 1 # pragma: no cover - - delta_ds = DeltaTableDataSet(filepath="") - catalog = DataCatalog(data_sets={"delta_in": delta_ds}) - pipeline = modular_pipeline([node(no_output, "delta_in", None)]) - pattern = ( - r"The following data sets cannot be used with " - r"multiprocessing: \['delta_in'\]" - ) - with pytest.raises(AttributeError, match=pattern): - ParallelRunner(is_async=is_async).run(pipeline, catalog) diff --git a/tests/extras/datasets/spark/test_memory_dataset.py b/tests/extras/datasets/spark/test_memory_dataset.py deleted file mode 100644 index d678f42d63..0000000000 --- a/tests/extras/datasets/spark/test_memory_dataset.py +++ /dev/null @@ -1,67 +0,0 @@ -import pytest -from pyspark.sql import DataFrame as SparkDataFrame -from pyspark.sql import SparkSession -from pyspark.sql.functions import col, when - -from kedro.io import MemoryDataset - - -def _update_spark_df(data, idx, jdx, value): - session = SparkSession.builder.getOrCreate() - data = session.createDataFrame(data.rdd.zipWithIndex()).select( - col("_1.*"), col("_2").alias("__id") - ) - cname = data.columns[idx] - return data.withColumn( - cname, when(col("__id") == jdx, value).otherwise(col(cname)) - ).drop("__id") - - -def _check_equals(data1, data2): - if isinstance(data1, SparkDataFrame) and isinstance(data2, SparkDataFrame): - return data1.toPandas().equals(data2.toPandas()) - return False # pragma: no cover - - -@pytest.fixture -def spark_data_frame(spark_session): - return spark_session.createDataFrame( - [(1, 4, 5), (2, 5, 6)], ["col1", "col2", "col3"] - ) - - -@pytest.fixture -def memory_dataset(spark_data_frame): - 
return MemoryDataset(data=spark_data_frame) - - -def test_load_modify_original_data(memory_dataset, spark_data_frame): - """Check that the data set object is not updated when the original - SparkDataFrame is changed.""" - spark_data_frame = _update_spark_df(spark_data_frame, 1, 1, -5) - assert not _check_equals(memory_dataset.load(), spark_data_frame) - - -def test_save_modify_original_data(spark_data_frame): - """Check that the data set object is not updated when the original - SparkDataFrame is changed.""" - memory_dataset = MemoryDataset() - memory_dataset.save(spark_data_frame) - spark_data_frame = _update_spark_df(spark_data_frame, 1, 1, "new value") - - assert not _check_equals(memory_dataset.load(), spark_data_frame) - - -def test_load_returns_same_spark_object(memory_dataset, spark_data_frame): - """Test that consecutive loads point to the same object in case of - a SparkDataFrame""" - loaded_data = memory_dataset.load() - reloaded_data = memory_dataset.load() - assert _check_equals(loaded_data, spark_data_frame) - assert _check_equals(reloaded_data, spark_data_frame) - assert loaded_data is reloaded_data - - -def test_str_representation(memory_dataset): - """Test string representation of the data set""" - assert "MemoryDataset(data=)" in str(memory_dataset) diff --git a/tests/extras/datasets/spark/test_spark_dataset.py b/tests/extras/datasets/spark/test_spark_dataset.py deleted file mode 100644 index a491ef6aeb..0000000000 --- a/tests/extras/datasets/spark/test_spark_dataset.py +++ /dev/null @@ -1,996 +0,0 @@ -import re -import sys -import tempfile -from pathlib import Path, PurePosixPath - -import boto3 -import pandas as pd -import pytest -from moto import mock_s3 -from pyspark import __version__ -from pyspark.sql import SparkSession -from pyspark.sql.functions import col -from pyspark.sql.types import ( - FloatType, - IntegerType, - StringType, - StructField, - StructType, -) -from pyspark.sql.utils import AnalysisException -from semver import VersionInfo - -from kedro.extras.datasets.pandas import CSVDataSet, ParquetDataSet -from kedro.extras.datasets.pickle import PickleDataSet -from kedro.extras.datasets.spark import SparkDataSet -from kedro.extras.datasets.spark.spark_dataset import ( - _dbfs_exists, - _dbfs_glob, - _get_dbutils, -) -from kedro.io import DataCatalog, DatasetError, Version -from kedro.io.core import generate_timestamp -from kedro.pipeline import node -from kedro.pipeline.modular_pipeline import pipeline as modular_pipeline -from kedro.runner import ParallelRunner, SequentialRunner - -FOLDER_NAME = "fake_folder" -FILENAME = "test.parquet" -BUCKET_NAME = "test_bucket" -SCHEMA_FILE_NAME = "schema.json" -AWS_CREDENTIALS = {"key": "FAKE_ACCESS_KEY", "secret": "FAKE_SECRET_KEY"} - -HDFS_PREFIX = f"{FOLDER_NAME}/{FILENAME}" -HDFS_FOLDER_STRUCTURE = [ - ( - HDFS_PREFIX, - [ - "2019-01-01T23.59.59.999Z", - "2019-01-02T00.00.00.000Z", - "2019-01-02T00.00.00.001Z", - "2019-01-02T01.00.00.000Z", - "2019-02-01T00.00.00.000Z", - ], - [], - ), - (HDFS_PREFIX + "/2019-01-01T23.59.59.999Z", [FILENAME], []), - (HDFS_PREFIX + "/2019-01-01T23.59.59.999Z/" + FILENAME, [], ["part1", "part2"]), - (HDFS_PREFIX + "/2019-01-02T00.00.00.000Z", [], ["other_file"]), - (HDFS_PREFIX + "/2019-01-02T00.00.00.001Z", [], []), - (HDFS_PREFIX + "/2019-01-02T01.00.00.000Z", [FILENAME], []), - (HDFS_PREFIX + "/2019-01-02T01.00.00.000Z/" + FILENAME, [], ["part1"]), - (HDFS_PREFIX + "/2019-02-01T00.00.00.000Z", [], ["other_file"]), -] - -SPARK_VERSION = VersionInfo.parse(__version__) - - 
-@pytest.fixture -def sample_pandas_df() -> pd.DataFrame: - return pd.DataFrame( - {"Name": ["Alex", "Bob", "Clarke", "Dave"], "Age": [31, 12, 65, 29]} - ) - - -@pytest.fixture -def version(): - load_version = None # use latest - save_version = generate_timestamp() # freeze save version - return Version(load_version, save_version) - - -@pytest.fixture -def versioned_dataset_local(tmp_path, version): - return SparkDataSet(filepath=(tmp_path / FILENAME).as_posix(), version=version) - - -@pytest.fixture -def versioned_dataset_dbfs(tmp_path, version): - return SparkDataSet( - filepath="/dbfs" + (tmp_path / FILENAME).as_posix(), version=version - ) - - -@pytest.fixture -def versioned_dataset_s3(version): - return SparkDataSet( - filepath=f"s3a://{BUCKET_NAME}/{FILENAME}", - version=version, - credentials=AWS_CREDENTIALS, - ) - - -@pytest.fixture -def sample_spark_df(): - schema = StructType( - [ - StructField("name", StringType(), True), - StructField("age", IntegerType(), True), - ] - ) - - data = [("Alex", 31), ("Bob", 12), ("Clarke", 65), ("Dave", 29)] - - return SparkSession.builder.getOrCreate().createDataFrame(data, schema) - - -@pytest.fixture -def sample_spark_df_schema() -> StructType: - return StructType( - [ - StructField("name", StringType(), True), - StructField("age", IntegerType(), True), - StructField("height", FloatType(), True), - ] - ) - - -def identity(arg): - return arg # pragma: no cover - - -@pytest.fixture -def spark_in(tmp_path, sample_spark_df): - spark_in = SparkDataSet(filepath=(tmp_path / "input").as_posix()) - spark_in.save(sample_spark_df) - return spark_in - - -@pytest.fixture -def mocked_s3_bucket(): - """Create a bucket for testing using moto.""" - with mock_s3(): - conn = boto3.client( - "s3", - aws_access_key_id="fake_access_key", - aws_secret_access_key="fake_secret_key", - ) - conn.create_bucket(Bucket=BUCKET_NAME) - yield conn - - -@pytest.fixture -def mocked_s3_schema(tmp_path, mocked_s3_bucket, sample_spark_df_schema: StructType): - """Creates schema file and adds it to mocked S3 bucket.""" - temporary_path = tmp_path / SCHEMA_FILE_NAME - temporary_path.write_text(sample_spark_df_schema.json(), encoding="utf-8") - - mocked_s3_bucket.put_object( - Bucket=BUCKET_NAME, Key=SCHEMA_FILE_NAME, Body=temporary_path.read_bytes() - ) - return mocked_s3_bucket - - -class FileInfo: - def __init__(self, path): - self.path = "dbfs:" + path - - def isDir(self): - return "." 
not in self.path.split("/")[-1] - - -class TestSparkDataSet: - def test_load_parquet(self, tmp_path, sample_pandas_df): - temp_path = (tmp_path / "data").as_posix() - local_parquet_set = ParquetDataSet(filepath=temp_path) - local_parquet_set.save(sample_pandas_df) - spark_data_set = SparkDataSet(filepath=temp_path) - spark_df = spark_data_set.load() - assert spark_df.count() == 4 - - def test_save_parquet(self, tmp_path, sample_spark_df): - # To cross check the correct Spark save operation we save to - # a single spark partition and retrieve it with Kedro - # ParquetDataSet - temp_dir = Path(str(tmp_path / "test_data")) - spark_data_set = SparkDataSet( - filepath=temp_dir.as_posix(), save_args={"compression": "none"} - ) - spark_df = sample_spark_df.coalesce(1) - spark_data_set.save(spark_df) - - single_parquet = [ - f for f in temp_dir.iterdir() if f.is_file() and f.name.startswith("part") - ][0] - - local_parquet_data_set = ParquetDataSet(filepath=single_parquet.as_posix()) - - pandas_df = local_parquet_data_set.load() - - assert pandas_df[pandas_df["name"] == "Bob"]["age"].iloc[0] == 12 - - def test_load_options_csv(self, tmp_path, sample_pandas_df): - filepath = (tmp_path / "data").as_posix() - local_csv_data_set = CSVDataSet(filepath=filepath) - local_csv_data_set.save(sample_pandas_df) - spark_data_set = SparkDataSet( - filepath=filepath, file_format="csv", load_args={"header": True} - ) - spark_df = spark_data_set.load() - assert spark_df.filter(col("Name") == "Alex").count() == 1 - - def test_load_options_schema_ddl_string( - self, tmp_path, sample_pandas_df, sample_spark_df_schema - ): - filepath = (tmp_path / "data").as_posix() - local_csv_data_set = CSVDataSet(filepath=filepath) - local_csv_data_set.save(sample_pandas_df) - spark_data_set = SparkDataSet( - filepath=filepath, - file_format="csv", - load_args={"header": True, "schema": "name STRING, age INT, height FLOAT"}, - ) - spark_df = spark_data_set.load() - assert spark_df.schema == sample_spark_df_schema - - def test_load_options_schema_obj( - self, tmp_path, sample_pandas_df, sample_spark_df_schema - ): - filepath = (tmp_path / "data").as_posix() - local_csv_data_set = CSVDataSet(filepath=filepath) - local_csv_data_set.save(sample_pandas_df) - - spark_data_set = SparkDataSet( - filepath=filepath, - file_format="csv", - load_args={"header": True, "schema": sample_spark_df_schema}, - ) - - spark_df = spark_data_set.load() - assert spark_df.schema == sample_spark_df_schema - - def test_load_options_schema_path( - self, tmp_path, sample_pandas_df, sample_spark_df_schema - ): - filepath = (tmp_path / "data").as_posix() - schemapath = (tmp_path / SCHEMA_FILE_NAME).as_posix() - local_csv_data_set = CSVDataSet(filepath=filepath) - local_csv_data_set.save(sample_pandas_df) - Path(schemapath).write_text(sample_spark_df_schema.json(), encoding="utf-8") - - spark_data_set = SparkDataSet( - filepath=filepath, - file_format="csv", - load_args={"header": True, "schema": {"filepath": schemapath}}, - ) - - spark_df = spark_data_set.load() - assert spark_df.schema == sample_spark_df_schema - - @pytest.mark.usefixtures("mocked_s3_schema") - def test_load_options_schema_path_with_credentials( - self, tmp_path, sample_pandas_df, sample_spark_df_schema - ): - filepath = (tmp_path / "data").as_posix() - local_csv_data_set = CSVDataSet(filepath=filepath) - local_csv_data_set.save(sample_pandas_df) - - spark_data_set = SparkDataSet( - filepath=filepath, - file_format="csv", - load_args={ - "header": True, - "schema": { - "filepath": 
f"s3://{BUCKET_NAME}/{SCHEMA_FILE_NAME}", - "credentials": AWS_CREDENTIALS, - }, - }, - ) - - spark_df = spark_data_set.load() - assert spark_df.schema == sample_spark_df_schema - - def test_load_options_invalid_schema_file(self, tmp_path): - filepath = (tmp_path / "data").as_posix() - schemapath = (tmp_path / SCHEMA_FILE_NAME).as_posix() - Path(schemapath).write_text("dummy", encoding="utf-8") - - pattern = ( - f"Contents of 'schema.filepath' ({schemapath}) are invalid. Please" - f"provide a valid JSON-serialised 'pyspark.sql.types.StructType'." - ) - - with pytest.raises(DatasetError, match=re.escape(pattern)): - SparkDataSet( - filepath=filepath, - file_format="csv", - load_args={"header": True, "schema": {"filepath": schemapath}}, - ) - - def test_load_options_invalid_schema(self, tmp_path): - filepath = (tmp_path / "data").as_posix() - - pattern = ( - "Schema load argument does not specify a 'filepath' attribute. Please" - "include a path to a JSON-serialised 'pyspark.sql.types.StructType'." - ) - - with pytest.raises(DatasetError, match=pattern): - SparkDataSet( - filepath=filepath, - file_format="csv", - load_args={"header": True, "schema": {}}, - ) - - def test_save_options_csv(self, tmp_path, sample_spark_df): - # To cross check the correct Spark save operation we save to - # a single spark partition with csv format and retrieve it with Kedro - # CSVDataSet - temp_dir = Path(str(tmp_path / "test_data")) - spark_data_set = SparkDataSet( - filepath=temp_dir.as_posix(), - file_format="csv", - save_args={"sep": "|", "header": True}, - ) - spark_df = sample_spark_df.coalesce(1) - spark_data_set.save(spark_df) - - single_csv_file = [ - f for f in temp_dir.iterdir() if f.is_file() and f.suffix == ".csv" - ][0] - - csv_local_data_set = CSVDataSet( - filepath=single_csv_file.as_posix(), load_args={"sep": "|"} - ) - pandas_df = csv_local_data_set.load() - - assert pandas_df[pandas_df["name"] == "Alex"]["age"][0] == 31 - - def test_str_representation(self): - with tempfile.NamedTemporaryFile() as temp_data_file: - filepath = Path(temp_data_file.name).as_posix() - spark_data_set = SparkDataSet( - filepath=filepath, file_format="csv", load_args={"header": True} - ) - assert "SparkDataSet" in str(spark_data_set) - assert f"filepath={filepath}" in str(spark_data_set) - - def test_save_overwrite_fail(self, tmp_path, sample_spark_df): - # Writes a data frame twice and expects it to fail. - filepath = (tmp_path / "test_data").as_posix() - spark_data_set = SparkDataSet(filepath=filepath) - spark_data_set.save(sample_spark_df) - - with pytest.raises(DatasetError): - spark_data_set.save(sample_spark_df) - - def test_save_overwrite_mode(self, tmp_path, sample_spark_df): - # Writes a data frame in overwrite mode. - filepath = (tmp_path / "test_data").as_posix() - spark_data_set = SparkDataSet( - filepath=filepath, save_args={"mode": "overwrite"} - ) - - spark_data_set.save(sample_spark_df) - spark_data_set.save(sample_spark_df) - - @pytest.mark.parametrize("mode", ["merge", "delete", "update"]) - def test_file_format_delta_and_unsupported_mode(self, tmp_path, mode): - filepath = (tmp_path / "test_data").as_posix() - pattern = ( - f"It is not possible to perform 'save()' for file format 'delta' " - f"with mode '{mode}' on 'SparkDataSet'. " - f"Please use 'spark.DeltaTableDataSet' instead." 
- ) - - with pytest.raises(DatasetError, match=re.escape(pattern)): - _ = SparkDataSet( - filepath=filepath, file_format="delta", save_args={"mode": mode} - ) - - def test_save_partition(self, tmp_path, sample_spark_df): - # To verify partitioning this test will partition the data by one - # of the columns and then check whether partitioned column is added - # to the save path - - filepath = Path(str(tmp_path / "test_data")) - spark_data_set = SparkDataSet( - filepath=filepath.as_posix(), - save_args={"mode": "overwrite", "partitionBy": ["name"]}, - ) - - spark_data_set.save(sample_spark_df) - - expected_path = filepath / "name=Alex" - - assert expected_path.exists() - - @pytest.mark.parametrize("file_format", ["csv", "parquet", "delta"]) - def test_exists(self, file_format, tmp_path, sample_spark_df): - filepath = (tmp_path / "test_data").as_posix() - spark_data_set = SparkDataSet(filepath=filepath, file_format=file_format) - - assert not spark_data_set.exists() - - spark_data_set.save(sample_spark_df) - assert spark_data_set.exists() - - def test_exists_raises_error(self, mocker): - # exists should raise all errors except for - # AnalysisExceptions clearly indicating a missing file - spark_data_set = SparkDataSet(filepath="") - if SPARK_VERSION.match(">=3.4.0"): - mocker.patch.object( - spark_data_set, - "_get_spark", - side_effect=AnalysisException("Other Exception"), - ) - else: - mocker.patch.object( # pylint: disable=expression-not-assigned - spark_data_set, - "_get_spark", - side_effect=AnalysisException("Other Exception", []), - ) - - with pytest.raises(DatasetError, match="Other Exception"): - spark_data_set.exists() - - @pytest.mark.parametrize("is_async", [False, True]) - def test_parallel_runner(self, is_async, spark_in): - """Test ParallelRunner with SparkDataSet fails.""" - catalog = DataCatalog(data_sets={"spark_in": spark_in}) - pipeline = modular_pipeline([node(identity, "spark_in", "spark_out")]) - pattern = ( - r"The following data sets cannot be used with " - r"multiprocessing: \['spark_in'\]" - ) - with pytest.raises(AttributeError, match=pattern): - ParallelRunner(is_async=is_async).run(pipeline, catalog) - - def test_s3_glob_refresh(self): - spark_dataset = SparkDataSet(filepath="s3a://bucket/data") - assert spark_dataset._glob_function.keywords == {"refresh": True} - - def test_copy(self): - spark_dataset = SparkDataSet( - filepath="/tmp/data", save_args={"mode": "overwrite"} - ) - assert spark_dataset._file_format == "parquet" - - spark_dataset_copy = spark_dataset._copy(_file_format="csv") - - assert spark_dataset is not spark_dataset_copy - assert spark_dataset._file_format == "parquet" - assert spark_dataset._save_args == {"mode": "overwrite"} - assert spark_dataset_copy._file_format == "csv" - assert spark_dataset_copy._save_args == {"mode": "overwrite"} - - -class TestSparkDataSetVersionedLocal: - def test_no_version(self, versioned_dataset_local): - pattern = r"Did not find any versions for SparkDataSet\(.+\)" - with pytest.raises(DatasetError, match=pattern): - versioned_dataset_local.load() - - def test_load_latest(self, versioned_dataset_local, sample_spark_df): - versioned_dataset_local.save(sample_spark_df) - reloaded = versioned_dataset_local.load() - - assert reloaded.exceptAll(sample_spark_df).count() == 0 - - def test_load_exact(self, tmp_path, sample_spark_df): - ts = generate_timestamp() - ds_local = SparkDataSet( - filepath=(tmp_path / FILENAME).as_posix(), version=Version(ts, ts) - ) - - ds_local.save(sample_spark_df) - reloaded = ds_local.load() 
- - assert reloaded.exceptAll(sample_spark_df).count() == 0 - - def test_save(self, versioned_dataset_local, version, tmp_path, sample_spark_df): - versioned_dataset_local.save(sample_spark_df) - assert (tmp_path / FILENAME / version.save / FILENAME).exists() - - def test_repr(self, versioned_dataset_local, tmp_path, version): - assert f"version=Version(load=None, save='{version.save}')" in str( - versioned_dataset_local - ) - - dataset_local = SparkDataSet(filepath=(tmp_path / FILENAME).as_posix()) - assert "version=" not in str(dataset_local) - - def test_save_version_warning(self, tmp_path, sample_spark_df): - exact_version = Version("2019-01-01T23.59.59.999Z", "2019-01-02T00.00.00.000Z") - ds_local = SparkDataSet( - filepath=(tmp_path / FILENAME).as_posix(), version=exact_version - ) - - pattern = ( - r"Save version '{ev.save}' did not match load version " - r"'{ev.load}' for SparkDataSet\(.+\)".format(ev=exact_version) - ) - with pytest.warns(UserWarning, match=pattern): - ds_local.save(sample_spark_df) - - def test_prevent_overwrite(self, tmp_path, version, sample_spark_df): - versioned_local = SparkDataSet( - filepath=(tmp_path / FILENAME).as_posix(), - version=version, - # second save should fail even in overwrite mode - save_args={"mode": "overwrite"}, - ) - versioned_local.save(sample_spark_df) - - pattern = ( - r"Save path '.+' for SparkDataSet\(.+\) must not exist " - r"if versioning is enabled" - ) - with pytest.raises(DatasetError, match=pattern): - versioned_local.save(sample_spark_df) - - def test_versioning_existing_dataset( - self, versioned_dataset_local, sample_spark_df - ): - """Check behavior when attempting to save a versioned dataset on top of an - already existing (non-versioned) dataset. Note: because SparkDataSet saves to a - directory even if non-versioned, an error is not expected.""" - spark_data_set = SparkDataSet( - filepath=versioned_dataset_local._filepath.as_posix() - ) - spark_data_set.save(sample_spark_df) - assert spark_data_set.exists() - versioned_dataset_local.save(sample_spark_df) - assert versioned_dataset_local.exists() - - -@pytest.mark.skipif( - sys.platform.startswith("win"), reason="DBFS doesn't work on Windows" -) -class TestSparkDataSetVersionedDBFS: - def test_load_latest( # noqa: too-many-arguments - self, mocker, versioned_dataset_dbfs, version, tmp_path, sample_spark_df - ): - mocked_glob = mocker.patch.object(versioned_dataset_dbfs, "_glob_function") - mocked_glob.return_value = [str(tmp_path / FILENAME / version.save / FILENAME)] - - versioned_dataset_dbfs.save(sample_spark_df) - reloaded = versioned_dataset_dbfs.load() - - expected_calls = [ - mocker.call("/dbfs" + str(tmp_path / FILENAME / "*" / FILENAME)) - ] - assert mocked_glob.call_args_list == expected_calls - - assert reloaded.exceptAll(sample_spark_df).count() == 0 - - def test_load_exact(self, tmp_path, sample_spark_df): - ts = generate_timestamp() - ds_dbfs = SparkDataSet( - filepath="/dbfs" + str(tmp_path / FILENAME), version=Version(ts, ts) - ) - - ds_dbfs.save(sample_spark_df) - reloaded = ds_dbfs.load() - - assert reloaded.exceptAll(sample_spark_df).count() == 0 - - def test_save( # noqa: too-many-arguments - self, mocker, versioned_dataset_dbfs, version, tmp_path, sample_spark_df - ): - mocked_glob = mocker.patch.object(versioned_dataset_dbfs, "_glob_function") - mocked_glob.return_value = [str(tmp_path / FILENAME / version.save / FILENAME)] - - versioned_dataset_dbfs.save(sample_spark_df) - - mocked_glob.assert_called_once_with( - "/dbfs" + str(tmp_path / FILENAME / 
"*" / FILENAME) - ) - assert (tmp_path / FILENAME / version.save / FILENAME).exists() - - def test_exists( # noqa: too-many-arguments - self, mocker, versioned_dataset_dbfs, version, tmp_path, sample_spark_df - ): - mocked_glob = mocker.patch.object(versioned_dataset_dbfs, "_glob_function") - mocked_glob.return_value = [str(tmp_path / FILENAME / version.save / FILENAME)] - - assert not versioned_dataset_dbfs.exists() - - versioned_dataset_dbfs.save(sample_spark_df) - assert versioned_dataset_dbfs.exists() - - expected_calls = [ - mocker.call("/dbfs" + str(tmp_path / FILENAME / "*" / FILENAME)) - ] * 2 - assert mocked_glob.call_args_list == expected_calls - - def test_dbfs_glob(self, mocker): - dbutils_mock = mocker.Mock() - dbutils_mock.fs.ls.return_value = [ - FileInfo("/tmp/file/date1"), - FileInfo("/tmp/file/date2"), - FileInfo("/tmp/file/file.csv"), - FileInfo("/tmp/file/"), - ] - pattern = "/tmp/file/*/file" - expected = ["/dbfs/tmp/file/date1/file", "/dbfs/tmp/file/date2/file"] - - result = _dbfs_glob(pattern, dbutils_mock) - assert result == expected - dbutils_mock.fs.ls.assert_called_once_with("/tmp/file") - - def test_dbfs_exists(self, mocker): - dbutils_mock = mocker.Mock() - test_path = "/dbfs/tmp/file/date1/file" - dbutils_mock.fs.ls.return_value = [ - FileInfo("/tmp/file/date1"), - FileInfo("/tmp/file/date2"), - FileInfo("/tmp/file/file.csv"), - FileInfo("/tmp/file/"), - ] - - assert _dbfs_exists(test_path, dbutils_mock) - - # add side effect to test that non-existence is handled - dbutils_mock.fs.ls.side_effect = Exception() - assert not _dbfs_exists(test_path, dbutils_mock) - - def test_ds_init_no_dbutils(self, mocker): - get_dbutils_mock = mocker.patch( - "kedro.extras.datasets.spark.spark_dataset._get_dbutils", return_value=None - ) - - data_set = SparkDataSet(filepath="/dbfs/tmp/data") - - get_dbutils_mock.assert_called_once() - assert data_set._glob_function.__name__ == "iglob" - - def test_ds_init_dbutils_available(self, mocker): - get_dbutils_mock = mocker.patch( - "kedro.extras.datasets.spark.spark_dataset._get_dbutils", - return_value="mock", - ) - - data_set = SparkDataSet(filepath="/dbfs/tmp/data") - - get_dbutils_mock.assert_called_once() - assert data_set._glob_function.__class__.__name__ == "partial" - assert data_set._glob_function.func.__name__ == "_dbfs_glob" - assert data_set._glob_function.keywords == { - "dbutils": get_dbutils_mock.return_value - } - - def test_get_dbutils_from_globals(self, mocker): - mocker.patch( - "kedro.extras.datasets.spark.spark_dataset.globals", - return_value={"dbutils": "dbutils_from_globals"}, - ) - assert _get_dbutils("spark") == "dbutils_from_globals" - - def test_get_dbutils_from_pyspark(self, mocker): - dbutils_mock = mocker.Mock() - dbutils_mock.DBUtils.return_value = "dbutils_from_pyspark" - mocker.patch.dict("sys.modules", {"pyspark.dbutils": dbutils_mock}) - assert _get_dbutils("spark") == "dbutils_from_pyspark" - dbutils_mock.DBUtils.assert_called_once_with("spark") - - def test_get_dbutils_from_ipython(self, mocker): - ipython_mock = mocker.Mock() - ipython_mock.get_ipython.return_value.user_ns = { - "dbutils": "dbutils_from_ipython" - } - mocker.patch.dict("sys.modules", {"IPython": ipython_mock}) - assert _get_dbutils("spark") == "dbutils_from_ipython" - ipython_mock.get_ipython.assert_called_once_with() - - def test_get_dbutils_no_modules(self, mocker): - mocker.patch( - "kedro.extras.datasets.spark.spark_dataset.globals", return_value={} - ) - mocker.patch.dict("sys.modules", {"pyspark": None, "IPython": None}) - 
assert _get_dbutils("spark") is None - - @pytest.mark.parametrize("os_name", ["nt", "posix"]) - def test_regular_path_in_different_os(self, os_name, mocker): - """Check that class of filepath depends on OS for regular path.""" - mocker.patch("os.name", os_name) - data_set = SparkDataSet(filepath="/some/path") - assert isinstance(data_set._filepath, PurePosixPath) - - @pytest.mark.parametrize("os_name", ["nt", "posix"]) - def test_dbfs_path_in_different_os(self, os_name, mocker): - """Check that class of filepath doesn't depend on OS if it references DBFS.""" - mocker.patch("os.name", os_name) - data_set = SparkDataSet(filepath="/dbfs/some/path") - assert isinstance(data_set._filepath, PurePosixPath) - - -class TestSparkDataSetVersionedS3: - def test_no_version(self, versioned_dataset_s3): - pattern = r"Did not find any versions for SparkDataSet\(.+\)" - with pytest.raises(DatasetError, match=pattern): - versioned_dataset_s3.load() - - def test_load_latest(self, mocker, versioned_dataset_s3): - get_spark = mocker.patch.object(versioned_dataset_s3, "_get_spark") - mocked_glob = mocker.patch.object(versioned_dataset_s3, "_glob_function") - mocked_glob.return_value = [ - "{b}/{f}/{v}/{f}".format(b=BUCKET_NAME, f=FILENAME, v="mocked_version") - ] - mocker.patch.object(versioned_dataset_s3, "_exists_function", return_value=True) - - versioned_dataset_s3.load() - - mocked_glob.assert_called_once_with( - "{b}/{f}/*/{f}".format(b=BUCKET_NAME, f=FILENAME) - ) - get_spark.return_value.read.load.assert_called_once_with( - "s3a://{b}/{f}/{v}/{f}".format( - b=BUCKET_NAME, f=FILENAME, v="mocked_version" - ), - "parquet", - ) - - def test_load_exact(self, mocker): - ts = generate_timestamp() - ds_s3 = SparkDataSet( - filepath=f"s3a://{BUCKET_NAME}/{FILENAME}", - version=Version(ts, None), - ) - get_spark = mocker.patch.object(ds_s3, "_get_spark") - - ds_s3.load() - - get_spark.return_value.read.load.assert_called_once_with( - "s3a://{b}/{f}/{v}/{f}".format(b=BUCKET_NAME, f=FILENAME, v=ts), "parquet" - ) - - def test_save(self, versioned_dataset_s3, version, mocker): - mocked_spark_df = mocker.Mock() - - # need resolve_load_version() call to return a load version that - # matches save version due to consistency check in versioned_dataset_s3.save() - mocker.patch.object( - versioned_dataset_s3, "resolve_load_version", return_value=version.save - ) - - versioned_dataset_s3.save(mocked_spark_df) - mocked_spark_df.write.save.assert_called_once_with( - "s3a://{b}/{f}/{v}/{f}".format(b=BUCKET_NAME, f=FILENAME, v=version.save), - "parquet", - ) - - def test_save_version_warning(self, mocker): - exact_version = Version("2019-01-01T23.59.59.999Z", "2019-01-02T00.00.00.000Z") - ds_s3 = SparkDataSet( - filepath=f"s3a://{BUCKET_NAME}/{FILENAME}", - version=exact_version, - credentials=AWS_CREDENTIALS, - ) - mocked_spark_df = mocker.Mock() - - pattern = ( - r"Save version '{ev.save}' did not match load version " - r"'{ev.load}' for SparkDataSet\(.+\)".format(ev=exact_version) - ) - with pytest.warns(UserWarning, match=pattern): - ds_s3.save(mocked_spark_df) - mocked_spark_df.write.save.assert_called_once_with( - "s3a://{b}/{f}/{v}/{f}".format( - b=BUCKET_NAME, f=FILENAME, v=exact_version.save - ), - "parquet", - ) - - def test_prevent_overwrite(self, mocker, versioned_dataset_s3): - mocked_spark_df = mocker.Mock() - mocker.patch.object(versioned_dataset_s3, "_exists_function", return_value=True) - - pattern = ( - r"Save path '.+' for SparkDataSet\(.+\) must not exist " - r"if versioning is enabled" - ) - with 
pytest.raises(DatasetError, match=pattern): - versioned_dataset_s3.save(mocked_spark_df) - - mocked_spark_df.write.save.assert_not_called() - - def test_s3n_warning(self, version): - pattern = ( - "'s3n' filesystem has now been deprecated by Spark, " - "please consider switching to 's3a'" - ) - with pytest.warns(DeprecationWarning, match=pattern): - SparkDataSet(filepath=f"s3n://{BUCKET_NAME}/{FILENAME}", version=version) - - def test_repr(self, versioned_dataset_s3, version): - assert "filepath=s3a://" in str(versioned_dataset_s3) - assert f"version=Version(load=None, save='{version.save}')" in str( - versioned_dataset_s3 - ) - - dataset_s3 = SparkDataSet(filepath=f"s3a://{BUCKET_NAME}/{FILENAME}") - assert "filepath=s3a://" in str(dataset_s3) - assert "version=" not in str(dataset_s3) - - -class TestSparkDataSetVersionedHdfs: - def test_no_version(self, mocker, version): - hdfs_walk = mocker.patch( - "kedro.extras.datasets.spark.spark_dataset.InsecureClient.walk" - ) - hdfs_walk.return_value = [] - - versioned_hdfs = SparkDataSet(filepath=f"hdfs://{HDFS_PREFIX}", version=version) - - pattern = r"Did not find any versions for SparkDataSet\(.+\)" - with pytest.raises(DatasetError, match=pattern): - versioned_hdfs.load() - - hdfs_walk.assert_called_once_with(HDFS_PREFIX) - - def test_load_latest(self, mocker, version): - mocker.patch( - "kedro.extras.datasets.spark.spark_dataset.InsecureClient.status", - return_value=True, - ) - hdfs_walk = mocker.patch( - "kedro.extras.datasets.spark.spark_dataset.InsecureClient.walk" - ) - hdfs_walk.return_value = HDFS_FOLDER_STRUCTURE - - versioned_hdfs = SparkDataSet(filepath=f"hdfs://{HDFS_PREFIX}", version=version) - get_spark = mocker.patch.object(versioned_hdfs, "_get_spark") - - versioned_hdfs.load() - - hdfs_walk.assert_called_once_with(HDFS_PREFIX) - get_spark.return_value.read.load.assert_called_once_with( - "hdfs://{fn}/{f}/{v}/{f}".format( - fn=FOLDER_NAME, v="2019-01-02T01.00.00.000Z", f=FILENAME - ), - "parquet", - ) - - def test_load_exact(self, mocker): - ts = generate_timestamp() - versioned_hdfs = SparkDataSet( - filepath=f"hdfs://{HDFS_PREFIX}", version=Version(ts, None) - ) - get_spark = mocker.patch.object(versioned_hdfs, "_get_spark") - - versioned_hdfs.load() - - get_spark.return_value.read.load.assert_called_once_with( - "hdfs://{fn}/{f}/{v}/{f}".format(fn=FOLDER_NAME, f=FILENAME, v=ts), - "parquet", - ) - - def test_save(self, mocker, version): - hdfs_status = mocker.patch( - "kedro.extras.datasets.spark.spark_dataset.InsecureClient.status" - ) - hdfs_status.return_value = None - - versioned_hdfs = SparkDataSet(filepath=f"hdfs://{HDFS_PREFIX}", version=version) - - # need resolve_load_version() call to return a load version that - # matches save version due to consistency check in versioned_hdfs.save() - mocker.patch.object( - versioned_hdfs, "resolve_load_version", return_value=version.save - ) - - mocked_spark_df = mocker.Mock() - versioned_hdfs.save(mocked_spark_df) - - hdfs_status.assert_called_once_with( - "{fn}/{f}/{v}/{f}".format(fn=FOLDER_NAME, v=version.save, f=FILENAME), - strict=False, - ) - mocked_spark_df.write.save.assert_called_once_with( - "hdfs://{fn}/{f}/{v}/{f}".format( - fn=FOLDER_NAME, v=version.save, f=FILENAME - ), - "parquet", - ) - - def test_save_version_warning(self, mocker): - exact_version = Version("2019-01-01T23.59.59.999Z", "2019-01-02T00.00.00.000Z") - versioned_hdfs = SparkDataSet( - filepath=f"hdfs://{HDFS_PREFIX}", version=exact_version - ) - mocker.patch.object(versioned_hdfs, 
"_exists_function", return_value=False) - mocked_spark_df = mocker.Mock() - - pattern = ( - r"Save version '{ev.save}' did not match load version " - r"'{ev.load}' for SparkDataSet\(.+\)".format(ev=exact_version) - ) - - with pytest.warns(UserWarning, match=pattern): - versioned_hdfs.save(mocked_spark_df) - mocked_spark_df.write.save.assert_called_once_with( - "hdfs://{fn}/{f}/{sv}/{f}".format( - fn=FOLDER_NAME, f=FILENAME, sv=exact_version.save - ), - "parquet", - ) - - def test_prevent_overwrite(self, mocker, version): - hdfs_status = mocker.patch( - "kedro.extras.datasets.spark.spark_dataset.InsecureClient.status" - ) - hdfs_status.return_value = True - - versioned_hdfs = SparkDataSet(filepath=f"hdfs://{HDFS_PREFIX}", version=version) - - mocked_spark_df = mocker.Mock() - - pattern = ( - r"Save path '.+' for SparkDataSet\(.+\) must not exist " - r"if versioning is enabled" - ) - with pytest.raises(DatasetError, match=pattern): - versioned_hdfs.save(mocked_spark_df) - - hdfs_status.assert_called_once_with( - "{fn}/{f}/{v}/{f}".format(fn=FOLDER_NAME, v=version.save, f=FILENAME), - strict=False, - ) - mocked_spark_df.write.save.assert_not_called() - - def test_hdfs_warning(self, version): - pattern = ( - "HDFS filesystem support for versioned SparkDataSet is in beta " - "and uses 'hdfs.client.InsecureClient', please use with caution" - ) - with pytest.warns(UserWarning, match=pattern): - SparkDataSet(filepath=f"hdfs://{HDFS_PREFIX}", version=version) - - def test_repr(self, version): - versioned_hdfs = SparkDataSet(filepath=f"hdfs://{HDFS_PREFIX}", version=version) - assert "filepath=hdfs://" in str(versioned_hdfs) - assert f"version=Version(load=None, save='{version.save}')" in str( - versioned_hdfs - ) - - dataset_hdfs = SparkDataSet(filepath=f"hdfs://{HDFS_PREFIX}") - assert "filepath=hdfs://" in str(dataset_hdfs) - assert "version=" not in str(dataset_hdfs) - - -@pytest.fixture -def data_catalog(tmp_path): - source_path = Path(__file__).parent / "data/test.parquet" - spark_in = SparkDataSet(source_path.as_posix()) - spark_out = SparkDataSet((tmp_path / "spark_data").as_posix()) - pickle_ds = PickleDataSet((tmp_path / "pickle/test.pkl").as_posix()) - - return DataCatalog( - {"spark_in": spark_in, "spark_out": spark_out, "pickle_ds": pickle_ds} - ) - - -@pytest.mark.parametrize("is_async", [False, True]) -class TestDataFlowSequentialRunner: - def test_spark_load_save(self, is_async, data_catalog): - """SparkDataSet(load) -> node -> Spark (save).""" - pipeline = modular_pipeline([node(identity, "spark_in", "spark_out")]) - SequentialRunner(is_async=is_async).run(pipeline, data_catalog) - - save_path = Path(data_catalog._data_sets["spark_out"]._filepath.as_posix()) - files = list(save_path.glob("*.parquet")) - assert len(files) > 0 - - def test_spark_pickle(self, is_async, data_catalog): - """SparkDataSet(load) -> node -> PickleDataSet (save)""" - pipeline = modular_pipeline([node(identity, "spark_in", "pickle_ds")]) - pattern = ".* was not serialised due to.*" - with pytest.raises(DatasetError, match=pattern): - SequentialRunner(is_async=is_async).run(pipeline, data_catalog) - - def test_spark_memory_spark(self, is_async, data_catalog): - """SparkDataSet(load) -> node -> MemoryDataSet (save and then load) -> - node -> SparkDataSet (save)""" - pipeline = modular_pipeline( - [ - node(identity, "spark_in", "memory_ds"), - node(identity, "memory_ds", "spark_out"), - ] - ) - SequentialRunner(is_async=is_async).run(pipeline, data_catalog) - - save_path = 
Path(data_catalog._data_sets["spark_out"]._filepath.as_posix()) - files = list(save_path.glob("*.parquet")) - assert len(files) > 0 diff --git a/tests/extras/datasets/spark/test_spark_hive_dataset.py b/tests/extras/datasets/spark/test_spark_hive_dataset.py deleted file mode 100644 index 399ebc4169..0000000000 --- a/tests/extras/datasets/spark/test_spark_hive_dataset.py +++ /dev/null @@ -1,311 +0,0 @@ -import gc -import re -from pathlib import Path -from tempfile import TemporaryDirectory - -import pytest -from psutil import Popen -from pyspark import SparkContext -from pyspark.sql import SparkSession -from pyspark.sql.types import IntegerType, StringType, StructField, StructType - -from kedro.extras.datasets.spark import SparkHiveDataSet -from kedro.io import DatasetError - -TESTSPARKDIR = "test_spark_dir" - - -@pytest.fixture(scope="module") -def spark_session(): - try: - with TemporaryDirectory(TESTSPARKDIR) as tmpdir: - spark = ( - SparkSession.builder.config( - "spark.local.dir", (Path(tmpdir) / "spark_local").absolute() - ) - .config( - "spark.sql.warehouse.dir", (Path(tmpdir) / "warehouse").absolute() - ) - .config( - "javax.jdo.option.ConnectionURL", - f"jdbc:derby:;" - f"databaseName={(Path(tmpdir) / 'warehouse_db').absolute()};" - f"create=true", - ) - .enableHiveSupport() - .getOrCreate() - ) - spark.sparkContext.setCheckpointDir( - str((Path(tmpdir) / "spark_checkpoint").absolute()) - ) - yield spark - - # This fixture should be a dependency of other fixtures dealing with spark hive data - # in this module so that it always exits last and stops the spark session - # after tests are finished. - spark.stop() - except PermissionError: # pragma: no cover - # On Windows machine TemporaryDirectory can't be removed because some - # files are still used by Java process. 
- pass - - # remove the cached JVM vars - SparkContext._jvm = None # pylint: disable=protected-access - SparkContext._gateway = None # pylint: disable=protected-access - - # py4j doesn't shutdown properly so kill the actual JVM process - for obj in gc.get_objects(): - try: - if isinstance(obj, Popen) and "pyspark" in obj.args[0]: - obj.terminate() # pragma: no cover - except ReferenceError: # pragma: no cover - # gc.get_objects may return dead weak proxy objects that will raise - # ReferenceError when you isinstance them - pass - - -@pytest.fixture(scope="module", autouse=True) -def spark_test_databases(spark_session): - """Setup spark test databases for all tests in this module.""" - dataset = _generate_spark_df_one() - dataset.createOrReplaceTempView("tmp") - databases = ["default_1", "default_2"] - - # Setup the databases and test table before testing - for database in databases: - spark_session.sql(f"create database {database}") - spark_session.sql("use default_1") - spark_session.sql("create table table_1 as select * from tmp") - - yield spark_session - - # Drop the databases after testing - for database in databases: - spark_session.sql(f"drop database {database} cascade") - - -def assert_df_equal(expected, result): - def indexRDD(data_frame): - return data_frame.rdd.zipWithIndex().map(lambda x: (x[1], x[0])) - - index_expected = indexRDD(expected) - index_result = indexRDD(result) - assert ( - index_expected.cogroup(index_result) - .map(lambda x: tuple(map(list, x[1]))) - .filter(lambda x: x[0] != x[1]) - .take(1) - == [] - ) - - -def _generate_spark_df_one(): - schema = StructType( - [ - StructField("name", StringType(), True), - StructField("age", IntegerType(), True), - ] - ) - data = [("Alex", 31), ("Bob", 12), ("Clarke", 65), ("Dave", 29)] - return SparkSession.builder.getOrCreate().createDataFrame(data, schema).coalesce(1) - - -def _generate_spark_df_upsert(): - schema = StructType( - [ - StructField("name", StringType(), True), - StructField("age", IntegerType(), True), - ] - ) - data = [("Alex", 99), ("Jeremy", 55)] - return SparkSession.builder.getOrCreate().createDataFrame(data, schema).coalesce(1) - - -def _generate_spark_df_upsert_expected(): - schema = StructType( - [ - StructField("name", StringType(), True), - StructField("age", IntegerType(), True), - ] - ) - data = [("Alex", 99), ("Bob", 12), ("Clarke", 65), ("Dave", 29), ("Jeremy", 55)] - return SparkSession.builder.getOrCreate().createDataFrame(data, schema).coalesce(1) - - -class TestSparkHiveDataSet: - def test_cant_pickle(self): - import pickle # pylint: disable=import-outside-toplevel - - with pytest.raises(pickle.PicklingError): - pickle.dumps( - SparkHiveDataSet( - database="default_1", table="table_1", write_mode="overwrite" - ) - ) - - def test_read_existing_table(self): - dataset = SparkHiveDataSet( - database="default_1", table="table_1", write_mode="overwrite", save_args={} - ) - assert_df_equal(_generate_spark_df_one(), dataset.load()) - - def test_overwrite_empty_table(self, spark_session): - spark_session.sql( - "create table default_1.test_overwrite_empty_table (name string, age integer)" - ).take(1) - dataset = SparkHiveDataSet( - database="default_1", - table="test_overwrite_empty_table", - write_mode="overwrite", - ) - dataset.save(_generate_spark_df_one()) - assert_df_equal(dataset.load(), _generate_spark_df_one()) - - def test_overwrite_not_empty_table(self, spark_session): - spark_session.sql( - "create table default_1.test_overwrite_full_table (name string, age integer)" - ).take(1) - dataset 
= SparkHiveDataSet( - database="default_1", - table="test_overwrite_full_table", - write_mode="overwrite", - ) - dataset.save(_generate_spark_df_one()) - dataset.save(_generate_spark_df_one()) - assert_df_equal(dataset.load(), _generate_spark_df_one()) - - def test_insert_not_empty_table(self, spark_session): - spark_session.sql( - "create table default_1.test_insert_not_empty_table (name string, age integer)" - ).take(1) - dataset = SparkHiveDataSet( - database="default_1", - table="test_insert_not_empty_table", - write_mode="append", - ) - dataset.save(_generate_spark_df_one()) - dataset.save(_generate_spark_df_one()) - assert_df_equal( - dataset.load(), _generate_spark_df_one().union(_generate_spark_df_one()) - ) - - def test_upsert_config_err(self): - # no pk provided should prompt config error - with pytest.raises( - DatasetError, match="'table_pk' must be set to utilise 'upsert' read mode" - ): - SparkHiveDataSet(database="default_1", table="table_1", write_mode="upsert") - - def test_upsert_empty_table(self, spark_session): - spark_session.sql( - "create table default_1.test_upsert_empty_table (name string, age integer)" - ).take(1) - dataset = SparkHiveDataSet( - database="default_1", - table="test_upsert_empty_table", - write_mode="upsert", - table_pk=["name"], - ) - dataset.save(_generate_spark_df_one()) - assert_df_equal( - dataset.load().sort("name"), _generate_spark_df_one().sort("name") - ) - - def test_upsert_not_empty_table(self, spark_session): - spark_session.sql( - "create table default_1.test_upsert_not_empty_table (name string, age integer)" - ).take(1) - dataset = SparkHiveDataSet( - database="default_1", - table="test_upsert_not_empty_table", - write_mode="upsert", - table_pk=["name"], - ) - dataset.save(_generate_spark_df_one()) - dataset.save(_generate_spark_df_upsert()) - - assert_df_equal( - dataset.load().sort("name"), - _generate_spark_df_upsert_expected().sort("name"), - ) - - def test_invalid_pk_provided(self): - _test_columns = ["column_doesnt_exist"] - dataset = SparkHiveDataSet( - database="default_1", - table="table_1", - write_mode="upsert", - table_pk=_test_columns, - ) - with pytest.raises( - DatasetError, - match=re.escape( - f"Columns {str(_test_columns)} selected as primary key(s) " - f"not found in table default_1.table_1", - ), - ): - dataset.save(_generate_spark_df_one()) - - def test_invalid_write_mode_provided(self): - pattern = ( - "Invalid 'write_mode' provided: not_a_write_mode. 
" - "'write_mode' must be one of: " - "append, error, errorifexists, upsert, overwrite" - ) - with pytest.raises(DatasetError, match=re.escape(pattern)): - SparkHiveDataSet( - database="default_1", - table="table_1", - write_mode="not_a_write_mode", - table_pk=["name"], - ) - - def test_invalid_schema_insert(self, spark_session): - spark_session.sql( - "create table default_1.test_invalid_schema_insert " - "(name string, additional_column_on_hive integer)" - ).take(1) - dataset = SparkHiveDataSet( - database="default_1", - table="test_invalid_schema_insert", - write_mode="append", - ) - with pytest.raises( - DatasetError, - match=r"Dataset does not match hive table schema\.\n" - r"Present on insert only: \[\('age', 'int'\)\]\n" - r"Present on schema only: \[\('additional_column_on_hive', 'int'\)\]", - ): - dataset.save(_generate_spark_df_one()) - - def test_insert_to_non_existent_table(self): - dataset = SparkHiveDataSet( - database="default_1", table="table_not_yet_created", write_mode="append" - ) - dataset.save(_generate_spark_df_one()) - assert_df_equal( - dataset.load().sort("name"), _generate_spark_df_one().sort("name") - ) - - def test_read_from_non_existent_table(self): - dataset = SparkHiveDataSet( - database="default_1", table="table_doesnt_exist", write_mode="append" - ) - with pytest.raises( - DatasetError, - match=r"Failed while loading data from data set SparkHiveDataSet" - r"|table_doesnt_exist" - r"|UnresolvedRelation", - ): - dataset.load() - - def test_save_delta_format(self, mocker): - dataset = SparkHiveDataSet( - database="default_1", table="delta_table", save_args={"format": "delta"} - ) - mocked_save = mocker.patch("pyspark.sql.DataFrameWriter.saveAsTable") - dataset.save(_generate_spark_df_one()) - mocked_save.assert_called_with( - "default_1.delta_table", mode="errorifexists", format="delta" - ) - assert dataset._format == "delta" diff --git a/tests/extras/datasets/spark/test_spark_jdbc_dataset.py b/tests/extras/datasets/spark/test_spark_jdbc_dataset.py deleted file mode 100644 index 6d89251fc5..0000000000 --- a/tests/extras/datasets/spark/test_spark_jdbc_dataset.py +++ /dev/null @@ -1,113 +0,0 @@ -import pytest - -from kedro.extras.datasets.spark import SparkJDBCDataSet -from kedro.io import DatasetError - - -@pytest.fixture -def spark_jdbc_args(): - return {"url": "dummy_url", "table": "dummy_table"} - - -@pytest.fixture -def spark_jdbc_args_credentials(spark_jdbc_args): - args = spark_jdbc_args - args.update({"credentials": {"user": "dummy_user", "password": "dummy_pw"}}) - return args - - -@pytest.fixture -def spark_jdbc_args_credentials_with_none_password(spark_jdbc_args): - args = spark_jdbc_args - args.update({"credentials": {"user": "dummy_user", "password": None}}) - return args - - -@pytest.fixture -def spark_jdbc_args_save_load(spark_jdbc_args): - args = spark_jdbc_args - connection_properties = {"properties": {"driver": "dummy_driver"}} - args.update( - {"save_args": connection_properties, "load_args": connection_properties} - ) - return args - - -def test_missing_url(): - error_message = ( - "'url' argument cannot be empty. Please provide a JDBC" - " URL of the form 'jdbc:subprotocol:subname'." - ) - with pytest.raises(DatasetError, match=error_message): - SparkJDBCDataSet(url=None, table="dummy_table") - - -def test_missing_table(): - error_message = ( - "'table' argument cannot be empty. Please provide" - " the name of the table to load or save data to." 
- ) - with pytest.raises(DatasetError, match=error_message): - SparkJDBCDataSet(url="dummy_url", table=None) - - -def test_save(mocker, spark_jdbc_args): - mock_data = mocker.Mock() - data_set = SparkJDBCDataSet(**spark_jdbc_args) - data_set.save(mock_data) - mock_data.write.jdbc.assert_called_with("dummy_url", "dummy_table") - - -def test_save_credentials(mocker, spark_jdbc_args_credentials): - mock_data = mocker.Mock() - data_set = SparkJDBCDataSet(**spark_jdbc_args_credentials) - data_set.save(mock_data) - mock_data.write.jdbc.assert_called_with( - "dummy_url", - "dummy_table", - properties={"user": "dummy_user", "password": "dummy_pw"}, - ) - - -def test_save_args(mocker, spark_jdbc_args_save_load): - mock_data = mocker.Mock() - data_set = SparkJDBCDataSet(**spark_jdbc_args_save_load) - data_set.save(mock_data) - mock_data.write.jdbc.assert_called_with( - "dummy_url", "dummy_table", properties={"driver": "dummy_driver"} - ) - - -def test_except_bad_credentials(mocker, spark_jdbc_args_credentials_with_none_password): - pattern = r"Credential property 'password' cannot be None(.+)" - with pytest.raises(DatasetError, match=pattern): - mock_data = mocker.Mock() - data_set = SparkJDBCDataSet(**spark_jdbc_args_credentials_with_none_password) - data_set.save(mock_data) - - -def test_load(mocker, spark_jdbc_args): - spark = mocker.patch.object(SparkJDBCDataSet, "_get_spark").return_value - data_set = SparkJDBCDataSet(**spark_jdbc_args) - data_set.load() - spark.read.jdbc.assert_called_with("dummy_url", "dummy_table") - - -def test_load_credentials(mocker, spark_jdbc_args_credentials): - spark = mocker.patch.object(SparkJDBCDataSet, "_get_spark").return_value - data_set = SparkJDBCDataSet(**spark_jdbc_args_credentials) - data_set.load() - spark.read.jdbc.assert_called_with( - "dummy_url", - "dummy_table", - properties={"user": "dummy_user", "password": "dummy_pw"}, - ) - - -def test_load_args(mocker, spark_jdbc_args_save_load): - spark = mocker.patch.object(SparkJDBCDataSet, "_get_spark").return_value - data_set = SparkJDBCDataSet(**spark_jdbc_args_save_load) - data_set.load() - spark.read.jdbc.assert_called_with( - "dummy_url", "dummy_table", properties={"driver": "dummy_driver"} - ) diff --git a/tests/extras/datasets/tensorflow/__init__.py b/tests/extras/datasets/tensorflow/__init__.py deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/tests/extras/datasets/tensorflow/test_tensorflow_model_dataset.py b/tests/extras/datasets/tensorflow/test_tensorflow_model_dataset.py deleted file mode 100644 index 69c5c46149..0000000000 --- a/tests/extras/datasets/tensorflow/test_tensorflow_model_dataset.py +++ /dev/null @@ -1,441 +0,0 @@ -# pylint: disable=import-outside-toplevel -from pathlib import PurePosixPath - -import numpy as np -import pytest -from fsspec.implementations.http import HTTPFileSystem -from fsspec.implementations.local import LocalFileSystem -from gcsfs import GCSFileSystem -from s3fs import S3FileSystem - -from kedro.io import DatasetError -from kedro.io.core import PROTOCOL_DELIMITER, Version - - -# In this test module, we wrap tensorflow and TensorFlowModelDataset imports into a module-scoped -# fixtures to avoid them being evaluated immediately when a new test process is spawned. -# Specifically: -# - ParallelRunner spawns a new subprocess. -# - pytest coverage is initialised on every new subprocess to update the global coverage -# statistics. -# - Coverage has to import the tests including tensorflow tests, which then import tensorflow. 
-# - tensorflow in eager mode triggers the remove_function method in -# tensorflow/python/eager/context.py, which acquires a threading.Lock. -# - Using a mutex/condition variable after fork (from the child process) is unsafe: -# it can lead to deadlocks" and can lead to segfault. -# -# So tl;dr is pytest-coverage importing of tensorflow creates a potential deadlock within -# a subprocess spawned by the parallel runner, so we wrap the import inside fixtures. -@pytest.fixture(scope="module") -def tf(): - import tensorflow as tf - - return tf - - -@pytest.fixture(scope="module") -def tensorflow_model_dataset(): - from kedro.extras.datasets.tensorflow import TensorFlowModelDataset - - return TensorFlowModelDataset - - -@pytest.fixture -def filepath(tmp_path): - return (tmp_path / "test_tf").as_posix() - - -@pytest.fixture -def dummy_x_train(): - return np.array([[[1.0], [1.0]], [[0.0], [0.0]]]) - - -@pytest.fixture -def dummy_y_train(): - return np.array([[[1], [1]], [[1], [1]]]) - - -@pytest.fixture -def dummy_x_test(): - return np.array([[[0.0], [0.0]], [[1.0], [1.0]]]) - - -@pytest.fixture -def tf_model_dataset(filepath, load_args, save_args, fs_args, tensorflow_model_dataset): - return tensorflow_model_dataset( - filepath=filepath, load_args=load_args, save_args=save_args, fs_args=fs_args - ) - - -@pytest.fixture -def versioned_tf_model_dataset( - filepath, load_version, save_version, tensorflow_model_dataset -): - return tensorflow_model_dataset( - filepath=filepath, version=Version(load_version, save_version) - ) - - -@pytest.fixture -def dummy_tf_base_model(dummy_x_train, dummy_y_train, tf): - # dummy 1 layer model as used in TF tests, see - # https://github.com/tensorflow/tensorflow/blob/8de272b3f3b73bea8d947c5f15143a9f1cfcfc6f/tensorflow/python/keras/models_test.py#L342 - inputs = tf.keras.Input(shape=(2, 1)) - x = tf.keras.layers.Dense(1)(inputs) - outputs = tf.keras.layers.Dense(1)(x) - - model = tf.keras.Model(inputs=inputs, outputs=outputs, name="1_layer_dummy") - model.compile("rmsprop", "mse") - model.fit(dummy_x_train, dummy_y_train, batch_size=64, epochs=1) - # from https://www.tensorflow.org/guide/keras/save_and_serialize - # Reset metrics before saving so that loaded model has same state, - # since metric states are not preserved by Model.save_weights - model.reset_metrics() - return model - - -@pytest.fixture -def dummy_tf_base_model_new(dummy_x_train, dummy_y_train, tf): - # dummy 2 layer model - inputs = tf.keras.Input(shape=(2, 1)) - x = tf.keras.layers.Dense(1)(inputs) - x = tf.keras.layers.Dense(1)(x) - outputs = tf.keras.layers.Dense(1)(x) - - model = tf.keras.Model(inputs=inputs, outputs=outputs, name="2_layer_dummy") - model.compile("rmsprop", "mse") - model.fit(dummy_x_train, dummy_y_train, batch_size=64, epochs=1) - # from https://www.tensorflow.org/guide/keras/save_and_serialize - # Reset metrics before saving so that loaded model has same state, - # since metric states are not preserved by Model.save_weights - model.reset_metrics() - return model - - -@pytest.fixture -def dummy_tf_subclassed_model(dummy_x_train, dummy_y_train, tf): - """Demonstrate that own class models cannot be saved - using HDF5 format but can using TF format - """ - - class MyModel(tf.keras.Model): - def __init__(self): - super().__init__() - self.dense1 = tf.keras.layers.Dense(4, activation=tf.nn.relu) - self.dense2 = tf.keras.layers.Dense(5, activation=tf.nn.softmax) - - # pylint: disable=unused-argument - def call(self, inputs, training=None, mask=None): # pragma: no cover - x = 
self.dense1(inputs) - return self.dense2(x) - - model = MyModel() - model.compile("rmsprop", "mse") - model.fit(dummy_x_train, dummy_y_train, batch_size=64, epochs=1) - return model - - -class TestTensorFlowModelDataset: - """No versioning passed to creator""" - - def test_save_and_load(self, tf_model_dataset, dummy_tf_base_model, dummy_x_test): - """Test saving and reloading the data set.""" - predictions = dummy_tf_base_model.predict(dummy_x_test) - tf_model_dataset.save(dummy_tf_base_model) - - reloaded = tf_model_dataset.load() - new_predictions = reloaded.predict(dummy_x_test) - np.testing.assert_allclose(predictions, new_predictions, rtol=1e-6, atol=1e-6) - - assert tf_model_dataset._load_args == {} - assert tf_model_dataset._save_args == {"save_format": "tf"} - - def test_load_missing_model(self, tf_model_dataset): - """Test error message when trying to load missing model.""" - pattern = ( - r"Failed while loading data from data set TensorFlowModelDataset\(.*\)" - ) - with pytest.raises(DatasetError, match=pattern): - tf_model_dataset.load() - - def test_exists(self, tf_model_dataset, dummy_tf_base_model): - """Test `exists` method invocation for both existing and nonexistent data set.""" - assert not tf_model_dataset.exists() - tf_model_dataset.save(dummy_tf_base_model) - assert tf_model_dataset.exists() - - def test_hdf5_save_format( - self, dummy_tf_base_model, dummy_x_test, filepath, tensorflow_model_dataset - ): - """Test TensorflowModelDataset can save TF graph models in HDF5 format""" - hdf5_dataset = tensorflow_model_dataset( - filepath=filepath, save_args={"save_format": "h5"} - ) - - predictions = dummy_tf_base_model.predict(dummy_x_test) - hdf5_dataset.save(dummy_tf_base_model) - - reloaded = hdf5_dataset.load() - new_predictions = reloaded.predict(dummy_x_test) - np.testing.assert_allclose(predictions, new_predictions, rtol=1e-6, atol=1e-6) - - def test_unused_subclass_model_hdf5_save_format( - self, - dummy_tf_subclassed_model, - dummy_x_train, - dummy_y_train, - dummy_x_test, - filepath, - tensorflow_model_dataset, - ): - """Test TensorflowModelDataset cannot save subclassed user models in HDF5 format - - Subclassed model - - From TF docs - First of all, a subclassed model that has never been used cannot be saved. - That's because a subclassed model needs to be called on some data in order to - create its weights. - """ - hdf5_data_set = tensorflow_model_dataset( - filepath=filepath, save_args={"save_format": "h5"} - ) - # demonstrating is a working model - dummy_tf_subclassed_model.fit( - dummy_x_train, dummy_y_train, batch_size=64, epochs=1 - ) - dummy_tf_subclassed_model.predict(dummy_x_test) - pattern = ( - r"Saving the model to HDF5 format requires the model to be a Functional model or a " - r"Sequential model. It does not work for subclassed models, because such models are " - r"defined via the body of a Python method, which isn\'t safely serializable. Consider " - r"saving to the Tensorflow SavedModel format \(by setting save_format=\"tf\"\) " - r"or using `save_weights`." 
- ) - with pytest.raises(DatasetError, match=pattern): - hdf5_data_set.save(dummy_tf_subclassed_model) - - @pytest.mark.parametrize( - "filepath,instance_type", - [ - ("s3://bucket/test_tf", S3FileSystem), - ("file:///tmp/test_tf", LocalFileSystem), - ("/tmp/test_tf", LocalFileSystem), - ("gcs://bucket/test_tf", GCSFileSystem), - ("https://example.com/test_tf", HTTPFileSystem), - ], - ) - def test_protocol_usage(self, filepath, instance_type, tensorflow_model_dataset): - """Test that it can be instantiated with mocked arbitrary file systems.""" - data_set = tensorflow_model_dataset(filepath=filepath) - assert isinstance(data_set._fs, instance_type) - - path = filepath.split(PROTOCOL_DELIMITER, 1)[-1] - - assert str(data_set._filepath) == path - assert isinstance(data_set._filepath, PurePosixPath) - - @pytest.mark.parametrize( - "load_args", [{"k1": "v1", "compile": False}], indirect=True - ) - def test_load_extra_params(self, tf_model_dataset, load_args): - """Test overriding the default load arguments.""" - for key, value in load_args.items(): - assert tf_model_dataset._load_args[key] == value - - def test_catalog_release(self, mocker, tensorflow_model_dataset): - fs_mock = mocker.patch("fsspec.filesystem").return_value - filepath = "test.tf" - data_set = tensorflow_model_dataset(filepath=filepath) - assert data_set._version_cache.currsize == 0 # no cache if unversioned - data_set.release() - fs_mock.invalidate_cache.assert_called_once_with(filepath) - assert data_set._version_cache.currsize == 0 - - @pytest.mark.parametrize("fs_args", [{"storage_option": "value"}]) - def test_fs_args(self, fs_args, mocker, tensorflow_model_dataset): - fs_mock = mocker.patch("fsspec.filesystem") - tensorflow_model_dataset("test.tf", fs_args=fs_args) - - fs_mock.assert_called_once_with("file", auto_mkdir=True, storage_option="value") - - def test_exists_with_exception(self, tf_model_dataset, mocker): - """Test `exists` method invocation when `get_filepath_str` raises an exception.""" - mocker.patch("kedro.io.core.get_filepath_str", side_effect=DatasetError) - assert not tf_model_dataset.exists() - - def test_save_and_overwrite_existing_model( - self, tf_model_dataset, dummy_tf_base_model, dummy_tf_base_model_new - ): - """Test models are correctly overwritten.""" - tf_model_dataset.save(dummy_tf_base_model) - - tf_model_dataset.save(dummy_tf_base_model_new) - - reloaded = tf_model_dataset.load() - - assert len(dummy_tf_base_model.layers) != len(reloaded.layers) - assert len(dummy_tf_base_model_new.layers) == len(reloaded.layers) - - -class TestTensorFlowModelDatasetVersioned: - """Test suite with versioning argument passed into TensorFlowModelDataset creator""" - - @pytest.mark.parametrize( - "load_version,save_version", - [ - ( - "2019-01-01T23.59.59.999Z", - "2019-01-01T23.59.59.999Z", - ), # long version names can fail on Win machines due to 260 max filepath - ( - None, - None, - ), # passing None default behaviour of generating timestamp for current time - ], - indirect=True, - ) - def test_save_and_load( - self, - dummy_tf_base_model, - versioned_tf_model_dataset, - dummy_x_test, - load_version, - save_version, - ): # pylint: disable=unused-argument - """Test saving and reloading the versioned data set.""" - - predictions = dummy_tf_base_model.predict(dummy_x_test) - versioned_tf_model_dataset.save(dummy_tf_base_model) - - reloaded = versioned_tf_model_dataset.load() - new_predictions = reloaded.predict(dummy_x_test) - np.testing.assert_allclose(predictions, new_predictions, rtol=1e-6, atol=1e-6) - - def 
test_hdf5_save_format( - self, - dummy_tf_base_model, - dummy_x_test, - filepath, - tensorflow_model_dataset, - load_version, - save_version, - ): - """Test versioned TensorflowModelDataset can save TF graph models in - HDF5 format""" - hdf5_dataset = tensorflow_model_dataset( - filepath=filepath, - save_args={"save_format": "h5"}, - version=Version(load_version, save_version), - ) - - predictions = dummy_tf_base_model.predict(dummy_x_test) - hdf5_dataset.save(dummy_tf_base_model) - - reloaded = hdf5_dataset.load() - new_predictions = reloaded.predict(dummy_x_test) - np.testing.assert_allclose(predictions, new_predictions, rtol=1e-6, atol=1e-6) - - def test_prevent_overwrite(self, dummy_tf_base_model, versioned_tf_model_dataset): - """Check the error when attempting to override the data set if the - corresponding file for a given save version already exists.""" - versioned_tf_model_dataset.save(dummy_tf_base_model) - pattern = ( - r"Save path \'.+\' for TensorFlowModelDataset\(.+\) must " - r"not exist if versioning is enabled\." - ) - with pytest.raises(DatasetError, match=pattern): - versioned_tf_model_dataset.save(dummy_tf_base_model) - - @pytest.mark.parametrize( - "load_version,save_version", - [("2019-01-01T23.59.59.999Z", "2019-01-02T00.00.00.000Z")], - indirect=True, - ) - def test_save_version_warning( - self, - versioned_tf_model_dataset, - load_version, - save_version, - dummy_tf_base_model, - ): - """Check the warning when saving to the path that differs from - the subsequent load path.""" - pattern = ( - rf"Save version '{save_version}' did not match load version '{load_version}' " - rf"for TensorFlowModelDataset\(.+\)" - ) - with pytest.warns(UserWarning, match=pattern): - versioned_tf_model_dataset.save(dummy_tf_base_model) - - def test_http_filesystem_no_versioning(self, tensorflow_model_dataset): - pattern = "Versioning is not supported for HTTP protocols." 
- - with pytest.raises(DatasetError, match=pattern): - tensorflow_model_dataset( - filepath="https://example.com/file.tf", version=Version(None, None) - ) - - def test_exists(self, versioned_tf_model_dataset, dummy_tf_base_model): - """Test `exists` method invocation for versioned data set.""" - assert not versioned_tf_model_dataset.exists() - versioned_tf_model_dataset.save(dummy_tf_base_model) - assert versioned_tf_model_dataset.exists() - - def test_no_versions(self, versioned_tf_model_dataset): - """Check the error if no versions are available for load.""" - pattern = r"Did not find any versions for TensorFlowModelDataset\(.+\)" - with pytest.raises(DatasetError, match=pattern): - versioned_tf_model_dataset.load() - - def test_version_str_repr(self, tf_model_dataset, versioned_tf_model_dataset): - """Test that version is in string representation of the class instance - when applicable.""" - - assert str(tf_model_dataset._filepath) in str(tf_model_dataset) - assert "version=" not in str(tf_model_dataset) - assert "protocol" in str(tf_model_dataset) - assert "save_args" in str(tf_model_dataset) - - assert str(versioned_tf_model_dataset._filepath) in str( - versioned_tf_model_dataset - ) - ver_str = f"version={versioned_tf_model_dataset._version}" - assert ver_str in str(versioned_tf_model_dataset) - assert "protocol" in str(versioned_tf_model_dataset) - assert "save_args" in str(versioned_tf_model_dataset) - - def test_versioning_existing_dataset( - self, tf_model_dataset, versioned_tf_model_dataset, dummy_tf_base_model - ): - """Check behavior when attempting to save a versioned dataset on top of an - already existing (non-versioned) dataset. Note: because TensorFlowModelDataset - saves to a directory even if non-versioned, an error is not expected.""" - tf_model_dataset.save(dummy_tf_base_model) - assert tf_model_dataset.exists() - assert tf_model_dataset._filepath == versioned_tf_model_dataset._filepath - versioned_tf_model_dataset.save(dummy_tf_base_model) - assert versioned_tf_model_dataset.exists() - - def test_save_and_load_with_device( - self, - dummy_tf_base_model, - dummy_x_test, - filepath, - tensorflow_model_dataset, - load_version, - save_version, - ): - """Test versioned TensorflowModelDataset can load models using an explicit tf_device""" - hdf5_dataset = tensorflow_model_dataset( - filepath=filepath, - load_args={"tf_device": "/CPU:0"}, - version=Version(load_version, save_version), - ) - - predictions = dummy_tf_base_model.predict(dummy_x_test) - hdf5_dataset.save(dummy_tf_base_model) - - reloaded = hdf5_dataset.load() - new_predictions = reloaded.predict(dummy_x_test) - np.testing.assert_allclose(predictions, new_predictions, rtol=1e-6, atol=1e-6) diff --git a/tests/extras/datasets/text/__init__.py b/tests/extras/datasets/text/__init__.py deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/tests/extras/datasets/text/test_text_dataset.py b/tests/extras/datasets/text/test_text_dataset.py deleted file mode 100644 index 1cb866988d..0000000000 --- a/tests/extras/datasets/text/test_text_dataset.py +++ /dev/null @@ -1,187 +0,0 @@ -from pathlib import Path, PurePosixPath - -import pytest -from fsspec.implementations.http import HTTPFileSystem -from fsspec.implementations.local import LocalFileSystem -from gcsfs import GCSFileSystem -from s3fs.core import S3FileSystem - -from kedro.extras.datasets.text import TextDataSet -from kedro.io import DatasetError -from kedro.io.core import PROTOCOL_DELIMITER, Version - -STRING = "Write to text file." 
- - -@pytest.fixture -def filepath_txt(tmp_path): - return (tmp_path / "test.txt").as_posix() - - -@pytest.fixture -def txt_data_set(filepath_txt, fs_args): - return TextDataSet(filepath=filepath_txt, fs_args=fs_args) - - -@pytest.fixture -def versioned_txt_data_set(filepath_txt, load_version, save_version): - return TextDataSet( - filepath=filepath_txt, version=Version(load_version, save_version) - ) - - -class TestTextDataSet: - def test_save_and_load(self, txt_data_set): - """Test saving and reloading the data set.""" - txt_data_set.save(STRING) - reloaded = txt_data_set.load() - assert STRING == reloaded - assert txt_data_set._fs_open_args_load == {"mode": "r"} - assert txt_data_set._fs_open_args_save == {"mode": "w"} - - def test_exists(self, txt_data_set): - """Test `exists` method invocation for both existing and - nonexistent data set.""" - assert not txt_data_set.exists() - txt_data_set.save(STRING) - assert txt_data_set.exists() - - @pytest.mark.parametrize( - "fs_args", - [{"open_args_load": {"mode": "rb", "compression": "gzip"}}], - indirect=True, - ) - def test_open_extra_args(self, txt_data_set, fs_args): - assert txt_data_set._fs_open_args_load == fs_args["open_args_load"] - assert txt_data_set._fs_open_args_save == {"mode": "w"} # default unchanged - - def test_load_missing_file(self, txt_data_set): - """Check the error when trying to load missing file.""" - pattern = r"Failed while loading data from data set TextDataSet\(.*\)" - with pytest.raises(DatasetError, match=pattern): - txt_data_set.load() - - @pytest.mark.parametrize( - "filepath,instance_type", - [ - ("s3://bucket/file.txt", S3FileSystem), - ("file:///tmp/test.txt", LocalFileSystem), - ("/tmp/test.txt", LocalFileSystem), - ("gcs://bucket/file.txt", GCSFileSystem), - ("https://example.com/file.txt", HTTPFileSystem), - ], - ) - def test_protocol_usage(self, filepath, instance_type): - data_set = TextDataSet(filepath=filepath) - assert isinstance(data_set._fs, instance_type) - - path = filepath.split(PROTOCOL_DELIMITER, 1)[-1] - - assert str(data_set._filepath) == path - assert isinstance(data_set._filepath, PurePosixPath) - - def test_catalog_release(self, mocker): - fs_mock = mocker.patch("fsspec.filesystem").return_value - filepath = "test.txt" - data_set = TextDataSet(filepath=filepath) - data_set.release() - fs_mock.invalidate_cache.assert_called_once_with(filepath) - - -class TestTextDataSetVersioned: - def test_version_str_repr(self, load_version, save_version): - """Test that version is in string representation of the class instance - when applicable.""" - filepath = "test.txt" - ds = TextDataSet(filepath=filepath) - ds_versioned = TextDataSet( - filepath=filepath, version=Version(load_version, save_version) - ) - assert filepath in str(ds) - assert "version" not in str(ds) - - assert filepath in str(ds_versioned) - ver_str = f"version=Version(load={load_version}, save='{save_version}')" - assert ver_str in str(ds_versioned) - assert "TextDataSet" in str(ds_versioned) - assert "TextDataSet" in str(ds) - assert "protocol" in str(ds_versioned) - assert "protocol" in str(ds) - - def test_save_and_load(self, versioned_txt_data_set): - """Test that saved and reloaded data matches the original one for - the versioned data set.""" - versioned_txt_data_set.save(STRING) - reloaded_df = versioned_txt_data_set.load() - assert STRING == reloaded_df - - def test_no_versions(self, versioned_txt_data_set): - """Check the error if no versions are available for load.""" - pattern = r"Did not find any versions for 
TextDataSet\(.+\)" - with pytest.raises(DatasetError, match=pattern): - versioned_txt_data_set.load() - - def test_exists(self, versioned_txt_data_set): - """Test `exists` method invocation for versioned data set.""" - assert not versioned_txt_data_set.exists() - versioned_txt_data_set.save(STRING) - assert versioned_txt_data_set.exists() - - def test_prevent_overwrite(self, versioned_txt_data_set): - """Check the error when attempting to override the data set if the - corresponding text file for a given save version already exists.""" - versioned_txt_data_set.save(STRING) - pattern = ( - r"Save path \'.+\' for TextDataSet\(.+\) must " - r"not exist if versioning is enabled\." - ) - with pytest.raises(DatasetError, match=pattern): - versioned_txt_data_set.save(STRING) - - @pytest.mark.parametrize( - "load_version", ["2019-01-01T23.59.59.999Z"], indirect=True - ) - @pytest.mark.parametrize( - "save_version", ["2019-01-02T00.00.00.000Z"], indirect=True - ) - def test_save_version_warning( - self, versioned_txt_data_set, load_version, save_version - ): - """Check the warning when saving to the path that differs from - the subsequent load path.""" - pattern = ( - rf"Save version '{save_version}' did not match load version " - rf"'{load_version}' for TextDataSet\(.+\)" - ) - with pytest.warns(UserWarning, match=pattern): - versioned_txt_data_set.save(STRING) - - def test_http_filesystem_no_versioning(self): - pattern = "Versioning is not supported for HTTP protocols." - - with pytest.raises(DatasetError, match=pattern): - TextDataSet( - filepath="https://example.com/file.txt", version=Version(None, None) - ) - - def test_versioning_existing_dataset( - self, - txt_data_set, - versioned_txt_data_set, - ): - """Check the error when attempting to save a versioned dataset on top of an - already existing (non-versioned) dataset.""" - txt_data_set.save(STRING) - assert txt_data_set.exists() - assert txt_data_set._filepath == versioned_txt_data_set._filepath - pattern = ( - f"(?=.*file with the same name already exists in the directory)" - f"(?=.*{versioned_txt_data_set._filepath.parent.as_posix()})" - ) - with pytest.raises(DatasetError, match=pattern): - versioned_txt_data_set.save(STRING) - - # Remove non-versioned dataset and try again - Path(txt_data_set._filepath.as_posix()).unlink() - versioned_txt_data_set.save(STRING) - assert versioned_txt_data_set.exists() diff --git a/tests/extras/datasets/tracking/__init__.py b/tests/extras/datasets/tracking/__init__.py deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/tests/extras/datasets/tracking/test_json_dataset.py b/tests/extras/datasets/tracking/test_json_dataset.py deleted file mode 100644 index 9e0c046558..0000000000 --- a/tests/extras/datasets/tracking/test_json_dataset.py +++ /dev/null @@ -1,185 +0,0 @@ -import json -from pathlib import Path, PurePosixPath - -import pytest -from fsspec.implementations.local import LocalFileSystem -from gcsfs import GCSFileSystem -from s3fs.core import S3FileSystem - -from kedro.extras.datasets.tracking import JSONDataSet -from kedro.io import DatasetError -from kedro.io.core import PROTOCOL_DELIMITER, Version - - -@pytest.fixture -def filepath_json(tmp_path): - return (tmp_path / "test.json").as_posix() - - -@pytest.fixture -def json_dataset(filepath_json, save_args, fs_args): - return JSONDataSet(filepath=filepath_json, save_args=save_args, fs_args=fs_args) - - -@pytest.fixture -def explicit_versioned_json_dataset(filepath_json, load_version, save_version): - return JSONDataSet( - 
filepath=filepath_json, version=Version(load_version, save_version) - ) - - -@pytest.fixture -def dummy_data(): - return {"col1": 1, "col2": 2, "col3": "mystring"} - - -class TestJSONDataSet: - def test_save(self, filepath_json, dummy_data, tmp_path, save_version): - """Test saving and reloading the data set.""" - json_dataset = JSONDataSet( - filepath=filepath_json, version=Version(None, save_version) - ) - json_dataset.save(dummy_data) - - actual_filepath = Path(json_dataset._filepath.as_posix()) - test_filepath = tmp_path / "locally_saved.json" - - test_filepath.parent.mkdir(parents=True, exist_ok=True) - with open(test_filepath, "w", encoding="utf-8") as file: - json.dump(dummy_data, file) - - with open(test_filepath, encoding="utf-8") as file: - test_data = json.load(file) - - with open( - (actual_filepath / save_version / "test.json"), encoding="utf-8" - ) as actual_file: - actual_data = json.load(actual_file) - - assert actual_data == test_data - assert json_dataset._fs_open_args_load == {} - assert json_dataset._fs_open_args_save == {"mode": "w"} - - def test_load_fail(self, json_dataset, dummy_data): - json_dataset.save(dummy_data) - pattern = r"Loading not supported for 'JSONDataSet'" - with pytest.raises(DatasetError, match=pattern): - json_dataset.load() - - def test_exists(self, json_dataset, dummy_data): - """Test `exists` method invocation for both existing and - nonexistent data set.""" - assert not json_dataset.exists() - json_dataset.save(dummy_data) - assert json_dataset.exists() - - @pytest.mark.parametrize( - "save_args", [{"k1": "v1", "index": "value"}], indirect=True - ) - def test_save_extra_params(self, json_dataset, save_args): - """Test overriding the default save arguments.""" - for key, value in save_args.items(): - assert json_dataset._save_args[key] == value - - @pytest.mark.parametrize( - "fs_args", - [{"open_args_load": {"mode": "rb", "compression": "gzip"}}], - indirect=True, - ) - def test_open_extra_args(self, json_dataset, fs_args): - assert json_dataset._fs_open_args_load == fs_args["open_args_load"] - assert json_dataset._fs_open_args_save == {"mode": "w"} # default unchanged - - @pytest.mark.parametrize( - "filepath,instance_type", - [ - ("s3://bucket/file.json", S3FileSystem), - ("file:///tmp/test.json", LocalFileSystem), - ("/tmp/test.json", LocalFileSystem), - ("gcs://bucket/file.json", GCSFileSystem), - ], - ) - def test_protocol_usage(self, filepath, instance_type): - data_set = JSONDataSet(filepath=filepath) - assert isinstance(data_set._fs, instance_type) - - path = filepath.split(PROTOCOL_DELIMITER, 1)[-1] - - assert str(data_set._filepath) == path - assert isinstance(data_set._filepath, PurePosixPath) - - def test_catalog_release(self, mocker): - fs_mock = mocker.patch("fsspec.filesystem").return_value - filepath = "test.json" - data_set = JSONDataSet(filepath=filepath) - data_set.release() - fs_mock.invalidate_cache.assert_called_once_with(filepath) - - def test_not_version_str_repr(self): - """Test that version is not in string representation of the class instance.""" - filepath = "test.json" - ds = JSONDataSet(filepath=filepath) - - assert filepath in str(ds) - assert "version" not in str(ds) - assert "JSONDataSet" in str(ds) - assert "protocol" in str(ds) - # Default save_args - assert "save_args={'indent': 2}" in str(ds) - - def test_version_str_repr(self, load_version, save_version): - """Test that version is in string representation of the class instance.""" - filepath = "test.json" - ds_versioned = JSONDataSet( - filepath=filepath, 
version=Version(load_version, save_version) - ) - - assert filepath in str(ds_versioned) - ver_str = f"version=Version(load={load_version}, save='{save_version}')" - assert ver_str in str(ds_versioned) - assert "JSONDataSet" in str(ds_versioned) - assert "protocol" in str(ds_versioned) - # Default save_args - assert "save_args={'indent': 2}" in str(ds_versioned) - - def test_prevent_overwrite(self, explicit_versioned_json_dataset, dummy_data): - """Check the error when attempting to override the data set if the - corresponding json file for a given save version already exists.""" - explicit_versioned_json_dataset.save(dummy_data) - pattern = ( - r"Save path \'.+\' for JSONDataSet\(.+\) must " - r"not exist if versioning is enabled\." - ) - with pytest.raises(DatasetError, match=pattern): - explicit_versioned_json_dataset.save(dummy_data) - - @pytest.mark.parametrize( - "load_version", ["2019-01-01T23.59.59.999Z"], indirect=True - ) - @pytest.mark.parametrize( - "save_version", ["2019-01-02T00.00.00.000Z"], indirect=True - ) - def test_save_version_warning( - self, - explicit_versioned_json_dataset, - load_version, - save_version, - dummy_data, - ): - """Check the warning when saving to the path that differs from - the subsequent load path.""" - pattern = ( - f"Save version '{save_version}' did not match " - f"load version '{load_version}' for " - r"JSONDataSet\(.+\)" - ) - with pytest.warns(UserWarning, match=pattern): - explicit_versioned_json_dataset.save(dummy_data) - - def test_http_filesystem_no_versioning(self): - pattern = "Versioning is not supported for HTTP protocols." - - with pytest.raises(DatasetError, match=pattern): - JSONDataSet( - filepath="https://example.com/file.json", version=Version(None, None) - ) diff --git a/tests/extras/datasets/tracking/test_metrics_dataset.py b/tests/extras/datasets/tracking/test_metrics_dataset.py deleted file mode 100644 index d65b50215d..0000000000 --- a/tests/extras/datasets/tracking/test_metrics_dataset.py +++ /dev/null @@ -1,194 +0,0 @@ -import json -from pathlib import Path, PurePosixPath - -import pytest -from fsspec.implementations.local import LocalFileSystem -from gcsfs import GCSFileSystem -from s3fs.core import S3FileSystem - -from kedro.extras.datasets.tracking import MetricsDataSet -from kedro.io import DatasetError -from kedro.io.core import PROTOCOL_DELIMITER, Version - - -@pytest.fixture -def filepath_json(tmp_path): - return (tmp_path / "test.json").as_posix() - - -@pytest.fixture -def metrics_dataset(filepath_json, save_args, fs_args): - return MetricsDataSet(filepath=filepath_json, save_args=save_args, fs_args=fs_args) - - -@pytest.fixture -def explicit_versioned_metrics_dataset(filepath_json, load_version, save_version): - return MetricsDataSet( - filepath=filepath_json, version=Version(load_version, save_version) - ) - - -@pytest.fixture -def dummy_data(): - return {"col1": 1, "col2": 2, "col3": 3} - - -class TestMetricsDataSet: - def test_save_data( - self, - dummy_data, - tmp_path, - filepath_json, - save_version, - ): - """Test saving and reloading the data set.""" - metrics_dataset = MetricsDataSet( - filepath=filepath_json, version=Version(None, save_version) - ) - metrics_dataset.save(dummy_data) - - actual_filepath = Path(metrics_dataset._filepath.as_posix()) - test_filepath = tmp_path / "locally_saved.json" - - test_filepath.parent.mkdir(parents=True, exist_ok=True) - with open(test_filepath, "w", encoding="utf-8") as file: - json.dump(dummy_data, file) - - with open(test_filepath, encoding="utf-8") as file: - 
test_data = json.load(file) - - with open( - (actual_filepath / save_version / "test.json"), encoding="utf-8" - ) as actual_file: - actual_data = json.load(actual_file) - - assert actual_data == test_data - assert metrics_dataset._fs_open_args_load == {} - assert metrics_dataset._fs_open_args_save == {"mode": "w"} - - def test_load_fail(self, metrics_dataset, dummy_data): - metrics_dataset.save(dummy_data) - pattern = r"Loading not supported for 'MetricsDataSet'" - with pytest.raises(DatasetError, match=pattern): - metrics_dataset.load() - - def test_exists(self, metrics_dataset, dummy_data): - """Test `exists` method invocation for both existing and - nonexistent data set.""" - assert not metrics_dataset.exists() - metrics_dataset.save(dummy_data) - assert metrics_dataset.exists() - - @pytest.mark.parametrize( - "save_args", [{"k1": "v1", "index": "value"}], indirect=True - ) - def test_save_extra_params(self, metrics_dataset, save_args): - """Test overriding the default save arguments.""" - for key, value in save_args.items(): - assert metrics_dataset._save_args[key] == value - - @pytest.mark.parametrize( - "fs_args", - [{"open_args_load": {"mode": "rb", "compression": "gzip"}}], - indirect=True, - ) - def test_open_extra_args(self, metrics_dataset, fs_args): - assert metrics_dataset._fs_open_args_load == fs_args["open_args_load"] - assert metrics_dataset._fs_open_args_save == {"mode": "w"} # default unchanged - - @pytest.mark.parametrize( - "filepath,instance_type", - [ - ("s3://bucket/file.json", S3FileSystem), - ("file:///tmp/test.json", LocalFileSystem), - ("/tmp/test.json", LocalFileSystem), - ("gcs://bucket/file.json", GCSFileSystem), - ], - ) - def test_protocol_usage(self, filepath, instance_type): - data_set = MetricsDataSet(filepath=filepath) - assert isinstance(data_set._fs, instance_type) - - path = filepath.split(PROTOCOL_DELIMITER, 1)[-1] - - assert str(data_set._filepath) == path - assert isinstance(data_set._filepath, PurePosixPath) - - def test_catalog_release(self, mocker): - fs_mock = mocker.patch("fsspec.filesystem").return_value - filepath = "test.json" - data_set = MetricsDataSet(filepath=filepath) - data_set.release() - fs_mock.invalidate_cache.assert_called_once_with(filepath) - - def test_fail_on_saving_non_numeric_value(self, metrics_dataset): - data = {"col1": 1, "col2": 2, "col3": "hello"} - - pattern = "The MetricsDataSet expects only numeric values." 
- with pytest.raises(DatasetError, match=pattern): - metrics_dataset.save(data) - - def test_not_version_str_repr(self): - """Test that version is not in string representation of the class instance.""" - filepath = "test.json" - ds = MetricsDataSet(filepath=filepath) - - assert filepath in str(ds) - assert "version" not in str(ds) - assert "MetricsDataSet" in str(ds) - assert "protocol" in str(ds) - # Default save_args - assert "save_args={'indent': 2}" in str(ds) - - def test_version_str_repr(self, load_version, save_version): - """Test that version is in string representation of the class instance.""" - filepath = "test.json" - ds_versioned = MetricsDataSet( - filepath=filepath, version=Version(load_version, save_version) - ) - - assert filepath in str(ds_versioned) - ver_str = f"version=Version(load={load_version}, save='{save_version}')" - assert ver_str in str(ds_versioned) - assert "MetricsDataSet" in str(ds_versioned) - assert "protocol" in str(ds_versioned) - # Default save_args - assert "save_args={'indent': 2}" in str(ds_versioned) - - def test_prevent_overwrite(self, explicit_versioned_metrics_dataset, dummy_data): - """Check the error when attempting to override the data set if the - corresponding json file for a given save version already exists.""" - explicit_versioned_metrics_dataset.save(dummy_data) - pattern = ( - r"Save path \'.+\' for MetricsDataSet\(.+\) must " - r"not exist if versioning is enabled\." - ) - with pytest.raises(DatasetError, match=pattern): - explicit_versioned_metrics_dataset.save(dummy_data) - - @pytest.mark.parametrize( - "load_version", ["2019-01-01T23.59.59.999Z"], indirect=True - ) - @pytest.mark.parametrize( - "save_version", ["2019-01-02T00.00.00.000Z"], indirect=True - ) - def test_save_version_warning( - self, explicit_versioned_metrics_dataset, load_version, save_version, dummy_data - ): - """Check the warning when saving to the path that differs from - the subsequent load path.""" - pattern = ( - f"Save version '{save_version}' did not match " - f"load version '{load_version}' for " - r"MetricsDataSet\(.+\)" - ) - with pytest.warns(UserWarning, match=pattern): - explicit_versioned_metrics_dataset.save(dummy_data) - - def test_http_filesystem_no_versioning(self): - pattern = "Versioning is not supported for HTTP protocols." 
- - with pytest.raises(DatasetError, match=pattern): - MetricsDataSet( - filepath="https://example.com/file.json", version=Version(None, None) - ) diff --git a/tests/extras/datasets/video/conftest.py b/tests/extras/datasets/video/conftest.py deleted file mode 100644 index ff084cdb5e..0000000000 --- a/tests/extras/datasets/video/conftest.py +++ /dev/null @@ -1,107 +0,0 @@ -from pathlib import Path - -import pytest -from PIL import Image -from utils import TEST_FPS, TEST_HEIGHT, TEST_WIDTH - -from kedro.extras.datasets.video.video_dataset import ( - FileVideo, - GeneratorVideo, - SequenceVideo, -) - - -@pytest.fixture(scope="module") -def red_frame(): - return Image.new("RGB", (TEST_WIDTH, TEST_HEIGHT), (255, 0, 0)) - - -@pytest.fixture(scope="module") -def green_frame(): - return Image.new("RGB", (TEST_WIDTH, TEST_HEIGHT), (0, 255, 0)) - - -@pytest.fixture(scope="module") -def blue_frame(): - return Image.new("RGB", (TEST_WIDTH, TEST_HEIGHT), (0, 0, 255)) - - -@pytest.fixture(scope="module") -def yellow_frame(): - return Image.new("RGB", (TEST_WIDTH, TEST_HEIGHT), (255, 255, 0)) - - -@pytest.fixture(scope="module") -def purple_frame(): - return Image.new("RGB", (TEST_WIDTH, TEST_HEIGHT), (255, 0, 255)) - - -@pytest.fixture -def color_video(red_frame, green_frame, blue_frame, yellow_frame, purple_frame): - return SequenceVideo( - [red_frame, green_frame, blue_frame, yellow_frame, purple_frame], - fps=TEST_FPS, - ) - - -@pytest.fixture -def color_video_generator( - red_frame, green_frame, blue_frame, yellow_frame, purple_frame -): - sequence = [red_frame, green_frame, blue_frame, yellow_frame, purple_frame] - - def generator(): - yield from sequence - - return GeneratorVideo( - generator(), - length=len(sequence), - fps=TEST_FPS, - ) - - -@pytest.fixture -def filepath_mp4(): - """This is a real video converted to mp4/h264 with ffmpeg command""" - return str(Path(__file__).parent / "data/video.mp4") - - -@pytest.fixture -def filepath_mkv(): - """This is a real video recorded with an Axis network camera""" - return str(Path(__file__).parent / "data/video.mkv") - - -@pytest.fixture -def filepath_mjpeg(): - """This is a real video recorded with an Axis network camera""" - return str(Path(__file__).parent / "data/video.mjpeg") - - -@pytest.fixture -def filepath_color_mp4(): - """This is a video created with the OpenCV VideoWriter - - it contains 5 frames, each of which is a single color: red, green, blue, yellow, purple - """ - return str(Path(__file__).parent / "data/color_video.mp4") - - -@pytest.fixture -def mp4_object(filepath_mp4): - return FileVideo(filepath_mp4) - - -@pytest.fixture -def mkv_object(filepath_mkv): - return FileVideo(filepath_mkv) - - -@pytest.fixture -def mjpeg_object(filepath_mjpeg): - return FileVideo(filepath_mjpeg) - - -@pytest.fixture -def color_video_object(filepath_color_mp4): - return FileVideo(filepath_color_mp4) diff --git a/tests/extras/datasets/video/data/color_video.mp4 b/tests/extras/datasets/video/data/color_video.mp4 deleted file mode 100644 index 01944b1b78..0000000000 Binary files a/tests/extras/datasets/video/data/color_video.mp4 and /dev/null differ diff --git a/tests/extras/datasets/video/data/video.mjpeg b/tests/extras/datasets/video/data/video.mjpeg deleted file mode 100644 index cab90dda94..0000000000 Binary files a/tests/extras/datasets/video/data/video.mjpeg and /dev/null differ diff --git a/tests/extras/datasets/video/data/video.mkv b/tests/extras/datasets/video/data/video.mkv deleted file mode 100644 index 2710c022ff..0000000000 Binary files 
a/tests/extras/datasets/video/data/video.mkv and /dev/null differ diff --git a/tests/extras/datasets/video/data/video.mp4 b/tests/extras/datasets/video/data/video.mp4 deleted file mode 100644 index 4c4b974d92..0000000000 Binary files a/tests/extras/datasets/video/data/video.mp4 and /dev/null differ diff --git a/tests/extras/datasets/video/test_sliced_video.py b/tests/extras/datasets/video/test_sliced_video.py deleted file mode 100644 index e2e4975d1a..0000000000 --- a/tests/extras/datasets/video/test_sliced_video.py +++ /dev/null @@ -1,56 +0,0 @@ -import numpy as np -from utils import TEST_HEIGHT, TEST_WIDTH - - -class TestSlicedVideo: - def test_slice_sequence_video_first(self, color_video): - """Test slicing and then indexing a SequenceVideo""" - slice_red_green = color_video[:2] - red = np.array(slice_red_green[0]) - assert red.shape == (TEST_HEIGHT, TEST_WIDTH, 3) - assert np.all(red[:, :, 0] == 255) - assert np.all(red[:, :, 1] == 0) - assert np.all(red[:, :, 2] == 0) - - def test_slice_sequence_video_last_as_index(self, color_video): - """Test slicing and then indexing a SequenceVideo""" - slice_blue_yellow_purple = color_video[2:5] - purple = np.array(slice_blue_yellow_purple[2]) - assert purple.shape == (TEST_HEIGHT, TEST_WIDTH, 3) - assert np.all(purple[:, :, 0] == 255) - assert np.all(purple[:, :, 1] == 0) - assert np.all(purple[:, :, 2] == 255) - - def test_slice_sequence_video_last_as_end(self, color_video): - """Test slicing and then indexing a SequenceVideo""" - slice_blue_yellow_purple = color_video[2:] - purple = np.array(slice_blue_yellow_purple[-1]) - assert purple.shape == (TEST_HEIGHT, TEST_WIDTH, 3) - assert np.all(purple[:, :, 0] == 255) - assert np.all(purple[:, :, 1] == 0) - assert np.all(purple[:, :, 2] == 255) - - def test_slice_sequence_attribute(self, color_video): - """Test that attributes from the base class are reachable from sliced views""" - slice_red_green = color_video[:2] - assert slice_red_green.fps == color_video.fps - - def test_slice_sliced_video(self, color_video): - """Test slicing and then indexing a SlicedVideo""" - slice_green_blue_yellow = color_video[1:4] - slice_green_blue = slice_green_blue_yellow[:-1] - blue = np.array(slice_green_blue[1]) - assert blue.shape == (TEST_HEIGHT, TEST_WIDTH, 3) - assert np.all(blue[:, :, 0] == 0) - assert np.all(blue[:, :, 1] == 0) - assert np.all(blue[:, :, 2] == 255) - - def test_slice_file_video_first(self, mp4_object): - """Test slicing and then indexing a FileVideo""" - sliced_video = mp4_object[:2] - assert np.all(np.array(sliced_video[0]) == np.array(mp4_object[0])) - - def test_slice_file_video_last(self, mp4_object): - """Test slicing and then indexing a FileVideo""" - sliced_video = mp4_object[-2:] - assert np.all(np.array(sliced_video[-1]) == np.array(mp4_object[-1])) diff --git a/tests/extras/datasets/video/test_video_dataset.py b/tests/extras/datasets/video/test_video_dataset.py deleted file mode 100644 index ceeb13929b..0000000000 --- a/tests/extras/datasets/video/test_video_dataset.py +++ /dev/null @@ -1,186 +0,0 @@ -import boto3 -import pytest -from moto import mock_s3 -from utils import TEST_FPS, assert_videos_equal - -from kedro.extras.datasets.video import VideoDataSet -from kedro.extras.datasets.video.video_dataset import FileVideo, SequenceVideo -from kedro.io import DatasetError - -S3_BUCKET_NAME = "test_bucket" -S3_KEY_PATH = "video" -S3_FULL_PATH = f"s3://{S3_BUCKET_NAME}/{S3_KEY_PATH}/" -AWS_CREDENTIALS = {"key": "FAKE_ACCESS_KEY", "secret": "FAKE_SECRET_KEY"} - - -@pytest.fixture -def 
tmp_filepath_mp4(tmp_path): - return (tmp_path / "test.mp4").as_posix() - - -@pytest.fixture -def tmp_filepath_avi(tmp_path): - return (tmp_path / "test.mjpeg").as_posix() - - -@pytest.fixture -def empty_dataset_mp4(tmp_filepath_mp4): - return VideoDataSet(filepath=tmp_filepath_mp4) - - -@pytest.fixture -def empty_dataset_avi(tmp_filepath_avi): - return VideoDataSet(filepath=tmp_filepath_avi) - - -@pytest.fixture -def mocked_s3_bucket(): - """Create a bucket for testing using moto.""" - with mock_s3(): - conn = boto3.client( - "s3", - region_name="us-east-1", - aws_access_key_id=AWS_CREDENTIALS["key"], - aws_secret_access_key=AWS_CREDENTIALS["secret"], - ) - conn.create_bucket(Bucket=S3_BUCKET_NAME) - yield conn - - -class TestVideoDataSet: - def test_load_mp4(self, filepath_mp4, mp4_object): - """Loading a mp4 dataset should create a FileVideo""" - ds = VideoDataSet(filepath_mp4) - loaded_video = ds.load() - assert_videos_equal(loaded_video, mp4_object) - - def test_save_and_load_mp4(self, empty_dataset_mp4, mp4_object): - """Test saving and reloading the data set.""" - empty_dataset_mp4.save(mp4_object) - reloaded_video = empty_dataset_mp4.load() - assert_videos_equal(mp4_object, reloaded_video) - assert reloaded_video.fourcc == empty_dataset_mp4._fourcc - - @pytest.mark.skip( - reason="Only one available codec that is typically installed when testing" - ) - def test_save_with_other_codec(self, tmp_filepath_mp4, mp4_object): - """Test saving the video with another codec than default.""" - save_fourcc = "xvid" - ds = VideoDataSet(filepath=tmp_filepath_mp4, fourcc=save_fourcc) - ds.save(mp4_object) - reloaded_video = ds.load() - assert reloaded_video.fourcc == save_fourcc - - def test_save_with_derived_codec(self, tmp_filepath_mp4, color_video): - """Test saving video by the codec specified in the video object""" - ds = VideoDataSet(filepath=tmp_filepath_mp4, fourcc=None) - ds.save(color_video) - reloaded_video = ds.load() - assert reloaded_video.fourcc == color_video.fourcc - - def test_saved_fps(self, empty_dataset_mp4, color_video): - """Verify that a saved video has the same framerate as specified in the video object""" - empty_dataset_mp4.save(color_video) - reloaded_video = empty_dataset_mp4.load() - assert reloaded_video.fps == TEST_FPS - - def test_save_sequence_video(self, color_video, empty_dataset_mp4): - """Test save (and load) a SequenceVideo object""" - empty_dataset_mp4.save(color_video) - reloaded_video = empty_dataset_mp4.load() - assert_videos_equal(color_video, reloaded_video) - - def test_save_generator_video( - self, color_video_generator, empty_dataset_mp4, color_video - ): - """Test save (and load) a GeneratorVideo object - - Since the GeneratorVideo is exhaused after saving the video to file we use - the SequenceVideo (color_video) which has the same frames to compare the - loaded video to. 
- """ - empty_dataset_mp4.save(color_video_generator) - reloaded_video = empty_dataset_mp4.load() - assert_videos_equal(color_video, reloaded_video) - - def test_exists(self, empty_dataset_mp4, mp4_object): - """Test `exists` method invocation for both existing and - nonexistent data set.""" - assert not empty_dataset_mp4.exists() - empty_dataset_mp4.save(mp4_object) - assert empty_dataset_mp4.exists() - - @pytest.mark.skip(reason="Can't deal with videos with missing time info") - def test_convert_video(self, empty_dataset_mp4, mjpeg_object): - """Load a file video in mjpeg format and save in mp4v""" - empty_dataset_mp4.save(mjpeg_object) - reloaded_video = empty_dataset_mp4.load() - assert_videos_equal(mjpeg_object, reloaded_video) - - def test_load_missing_file(self, empty_dataset_mp4): - """Check the error when trying to load missing file.""" - pattern = r"Failed while loading data from data set VideoDataSet\(.*\)" - with pytest.raises(DatasetError, match=pattern): - empty_dataset_mp4.load() - - def test_save_s3(self, mp4_object, mocked_s3_bucket, tmp_path): - """Test to save a VideoDataSet to S3 storage""" - video_name = "video.mp4" - - dataset = VideoDataSet( - filepath=S3_FULL_PATH + video_name, credentials=AWS_CREDENTIALS - ) - dataset.save(mp4_object) - - tmp_file = tmp_path / video_name - mocked_s3_bucket.download_file( - Bucket=S3_BUCKET_NAME, - Key=S3_KEY_PATH + "/" + video_name, - Filename=str(tmp_file), - ) - reloaded_video = FileVideo(str(tmp_file)) - assert_videos_equal(reloaded_video, mp4_object) - - @pytest.mark.xfail - @pytest.mark.parametrize( - "fourcc, suffix", - [ - ("mp4v", "mp4"), - ("mp4v", "mjpeg"), - ("mp4v", "avi"), - ("avc1", "mp4"), - ("avc1", "mjpeg"), - ("avc1", "avi"), - ("mjpg", "mp4"), - ("mjpg", "mjpeg"), - ("mjpg", "avi"), - ("xvid", "mp4"), - ("xvid", "mjpeg"), - ("xvid", "avi"), - ("x264", "mp4"), - ("x264", "mjpeg"), - ("x264", "avi"), - ("divx", "mp4"), - ("divx", "mjpeg"), - ("divx", "avi"), - ("fmp4", "mp4"), - ("fmp4", "mjpeg"), - ("fmp4", "avi"), - ], - ) - def test_video_codecs(self, fourcc, suffix, color_video): - """Test different codec and container combinations - - Some of these are expected to fail depending on what - codecs are installed on the machine. 
- """ - video_name = f"video.{suffix}" - video = SequenceVideo(color_video._frames, 25, fourcc) - ds = VideoDataSet(video_name, fourcc=None) - ds.save(video) - # We also need to verify that the correct codec was used - # since OpenCV silently (with a warning in the log) falls back to - # another codec if the one specified is not compatible with the container - reloaded_video = ds.load() - assert reloaded_video.fourcc == fourcc diff --git a/tests/extras/datasets/video/test_video_objects.py b/tests/extras/datasets/video/test_video_objects.py deleted file mode 100644 index 66a284fa60..0000000000 --- a/tests/extras/datasets/video/test_video_objects.py +++ /dev/null @@ -1,170 +0,0 @@ -import numpy as np -import pytest -from utils import ( - DEFAULT_FOURCC, - MJPEG_FOURCC, - MJPEG_FPS, - MJPEG_LEN, - MJPEG_SIZE, - MKV_FOURCC, - MKV_FPS, - MKV_LEN, - MKV_SIZE, - MP4_FOURCC, - MP4_FPS, - MP4_LEN, - MP4_SIZE, - TEST_FPS, - TEST_HEIGHT, - TEST_NUM_COLOR_FRAMES, - TEST_WIDTH, - assert_images_equal, -) - -from kedro.extras.datasets.video.video_dataset import ( - FileVideo, - GeneratorVideo, - SequenceVideo, -) - - -class TestSequenceVideo: - def test_sequence_video_indexing_first(self, color_video, red_frame): - """Test indexing a SequenceVideo""" - red = np.array(color_video[0]) - assert red.shape == (TEST_HEIGHT, TEST_WIDTH, 3) - assert np.all(red == red_frame) - - def test_sequence_video_indexing_last(self, color_video, purple_frame): - """Test indexing a SequenceVideo""" - purple = np.array(color_video[-1]) - assert purple.shape == (TEST_HEIGHT, TEST_WIDTH, 3) - assert np.all(purple == purple_frame) - - def test_sequence_video_iterable(self, color_video): - """Test iterating a SequenceVideo""" - for i, img in enumerate(map(np.array, color_video)): - assert np.all(img == np.array(color_video[i])) - assert i == TEST_NUM_COLOR_FRAMES - 1 - - def test_sequence_video_fps(self, color_video): - # Test the one set by the fixture - assert color_video.fps == TEST_FPS - - # Test creating with another fps - test_fps_new = 123 - color_video_new = SequenceVideo(color_video._frames, fps=test_fps_new) - assert color_video_new.fps == test_fps_new - - def test_sequence_video_len(self, color_video): - assert len(color_video) == TEST_NUM_COLOR_FRAMES - - def test_sequence_video_size(self, color_video): - assert color_video.size == (TEST_WIDTH, TEST_HEIGHT) - - def test_sequence_video_fourcc_default_value(self, color_video): - assert color_video.fourcc == DEFAULT_FOURCC - - def test_sequence_video_fourcc(self, color_video): - fourcc_new = "mjpg" - assert ( - DEFAULT_FOURCC != fourcc_new - ), "Test does not work if new test value is same as default" - color_video_new = SequenceVideo( - color_video._frames, fps=TEST_FPS, fourcc=fourcc_new - ) - assert color_video_new.fourcc == fourcc_new - - -class TestGeneratorVideo: - def test_generator_video_iterable(self, color_video_generator, color_video): - """Test iterating a GeneratorVideo - - The content of the mock GeneratorVideo should be the same as the SequenceVideo; - the content in the latter is tested in other unit tests and can thus be trusted - """ - for i, img in enumerate(map(np.array, color_video_generator)): - assert np.all(img == np.array(color_video[i])) - assert i == TEST_NUM_COLOR_FRAMES - 1 - - def test_generator_video_fps(self, color_video_generator): - # Test the one set by the fixture - assert color_video_generator.fps == TEST_FPS - - # Test creating with another fps - test_fps_new = 123 - color_video_new = GeneratorVideo( - color_video_generator._gen, 
length=TEST_NUM_COLOR_FRAMES, fps=test_fps_new - ) - assert color_video_new.fps == test_fps_new - - def test_generator_video_len(self, color_video_generator): - assert len(color_video_generator) == TEST_NUM_COLOR_FRAMES - - def test_generator_video_size(self, color_video_generator): - assert color_video_generator.size == (TEST_WIDTH, TEST_HEIGHT) - - def test_generator_video_fourcc_default_value(self, color_video_generator): - assert color_video_generator.fourcc == DEFAULT_FOURCC - - def test_generator_video_fourcc(self, color_video_generator): - fourcc_new = "mjpg" - assert ( - DEFAULT_FOURCC != fourcc_new - ), "Test does not work if new test value is same as default" - color_video_new = GeneratorVideo( - color_video_generator._gen, - length=TEST_NUM_COLOR_FRAMES, - fps=TEST_FPS, - fourcc=fourcc_new, - ) - assert color_video_new.fourcc == fourcc_new - - -class TestFileVideo: - @pytest.mark.skip(reason="Can't deal with videos with missing time info") - def test_file_props_mjpeg(self, mjpeg_object): - assert mjpeg_object.fourcc == MJPEG_FOURCC - assert mjpeg_object.fps == MJPEG_FPS - assert mjpeg_object.size == MJPEG_SIZE - assert len(mjpeg_object) == MJPEG_LEN - - def test_file_props_mkv(self, mkv_object): - assert mkv_object.fourcc == MKV_FOURCC - assert mkv_object.fps == MKV_FPS - assert mkv_object.size == MKV_SIZE - assert len(mkv_object) == MKV_LEN - - def test_file_props_mp4(self, mp4_object): - assert mp4_object.fourcc == MP4_FOURCC - assert mp4_object.fps == MP4_FPS - assert mp4_object.size == MP4_SIZE - assert len(mp4_object) == MP4_LEN - - def test_file_index_first(self, color_video_object, red_frame): - assert_images_equal(color_video_object[0], red_frame) - - def test_file_index_last_by_index(self, color_video_object, purple_frame): - assert_images_equal(color_video_object[TEST_NUM_COLOR_FRAMES - 1], purple_frame) - - def test_file_index_last(self, color_video_object, purple_frame): - assert_images_equal(color_video_object[-1], purple_frame) - - def test_file_video_failed_capture(self, mocker): - """Validate good behavior on failed decode - - The best behavior in this case is not obvious, the len property of the - video object specifies more frames than is actually possible to decode. 
We - cannot know this in advance without spending loads of time to decode all frames - in order to count them.""" - mock_cv2 = mocker.patch("kedro.extras.datasets.video.video_dataset.cv2") - mock_cap = mock_cv2.VideoCapture.return_value = mocker.Mock() - mock_cap.get.return_value = 2 # Set the length of the video - ds = FileVideo("/a/b/c") - - mock_cap.read.return_value = True, np.zeros((1, 1)) - assert ds[0] - - mock_cap.read.return_value = False, None - with pytest.raises(IndexError): - ds[1] diff --git a/tests/extras/datasets/video/utils.py b/tests/extras/datasets/video/utils.py deleted file mode 100644 index 6b675aed2f..0000000000 --- a/tests/extras/datasets/video/utils.py +++ /dev/null @@ -1,49 +0,0 @@ -import itertools - -import numpy as np -from PIL import ImageChops - -TEST_WIDTH = 640 # Arbitrary value for testing -TEST_HEIGHT = 480 # Arbitrary value for testing -TEST_FPS = 1 # Arbitrary value for testing - -TEST_NUM_COLOR_FRAMES = ( - 5 # This should be the same as number of frames in conftest videos -) -DEFAULT_FOURCC = "mp4v" # The expected default fourcc value - -# This is video data extracted from the video files with ffmpeg command -MKV_SIZE = (640, 360) -MKV_FPS = 50 -MKV_FOURCC = "h264" -MKV_LEN = 109 # from ffprobe - -MP4_SIZE = (640, 360) -MP4_FPS = 50 -MP4_FOURCC = "avc1" -MP4_LEN = 109 # from ffprobe - -MJPEG_SIZE = (640, 360) -MJPEG_FPS = 25 # From ffprobe, not reported by ffmpeg command -# I'm not sure that MJPE is the correct fourcc code for -# mjpeg video since I cannot find any official reference to -# that code. This is however what the openCV VideoCapture -# reports for the video, so we leave it like this for now.. -MJPEG_FOURCC = "mjpe" -MJPEG_LEN = 24 # from ffprobe - - -def assert_images_equal(image_1, image_2): - """Assert that two images are approximately equal, allow for some - compression artifacts""" - assert image_1.size == image_2.size - diff = np.asarray(ImageChops.difference(image_1, image_2)) - assert np.mean(diff) < 5 - assert np.mean(diff > 50) < 0.01 # Max 1% of pixels - - -def assert_videos_equal(video_1, video_2): - assert len(video_1) == len(video_2) - - for image_1, image_2 in itertools.zip_longest(video_1, video_2): - assert_images_equal(image_1, image_2) diff --git a/tests/extras/datasets/yaml/__init__.py b/tests/extras/datasets/yaml/__init__.py deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/tests/extras/datasets/yaml/test_yaml_dataset.py b/tests/extras/datasets/yaml/test_yaml_dataset.py deleted file mode 100644 index 432afaed3d..0000000000 --- a/tests/extras/datasets/yaml/test_yaml_dataset.py +++ /dev/null @@ -1,210 +0,0 @@ -from pathlib import Path, PurePosixPath - -import pandas as pd -import pytest -from fsspec.implementations.http import HTTPFileSystem -from fsspec.implementations.local import LocalFileSystem -from gcsfs import GCSFileSystem -from pandas.testing import assert_frame_equal -from s3fs.core import S3FileSystem - -from kedro.extras.datasets.yaml import YAMLDataSet -from kedro.io import DatasetError -from kedro.io.core import PROTOCOL_DELIMITER, Version - - -@pytest.fixture -def filepath_yaml(tmp_path): - return (tmp_path / "test.yaml").as_posix() - - -@pytest.fixture -def yaml_data_set(filepath_yaml, save_args, fs_args): - return YAMLDataSet(filepath=filepath_yaml, save_args=save_args, fs_args=fs_args) - - -@pytest.fixture -def versioned_yaml_data_set(filepath_yaml, load_version, save_version): - return YAMLDataSet( - filepath=filepath_yaml, version=Version(load_version, save_version) - ) - - 
-@pytest.fixture -def dummy_data(): - return {"col1": 1, "col2": 2, "col3": 3} - - -class TestYAMLDataSet: - def test_save_and_load(self, yaml_data_set, dummy_data): - """Test saving and reloading the data set.""" - yaml_data_set.save(dummy_data) - reloaded = yaml_data_set.load() - assert dummy_data == reloaded - assert yaml_data_set._fs_open_args_load == {} - assert yaml_data_set._fs_open_args_save == {"mode": "w"} - - def test_exists(self, yaml_data_set, dummy_data): - """Test `exists` method invocation for both existing and - nonexistent data set.""" - assert not yaml_data_set.exists() - yaml_data_set.save(dummy_data) - assert yaml_data_set.exists() - - @pytest.mark.parametrize( - "save_args", [{"k1": "v1", "index": "value"}], indirect=True - ) - def test_save_extra_params(self, yaml_data_set, save_args): - """Test overriding the default save arguments.""" - for key, value in save_args.items(): - assert yaml_data_set._save_args[key] == value - - @pytest.mark.parametrize( - "fs_args", - [{"open_args_load": {"mode": "rb", "compression": "gzip"}}], - indirect=True, - ) - def test_open_extra_args(self, yaml_data_set, fs_args): - assert yaml_data_set._fs_open_args_load == fs_args["open_args_load"] - assert yaml_data_set._fs_open_args_save == {"mode": "w"} # default unchanged - - def test_load_missing_file(self, yaml_data_set): - """Check the error when trying to load missing file.""" - pattern = r"Failed while loading data from data set YAMLDataSet\(.*\)" - with pytest.raises(DatasetError, match=pattern): - yaml_data_set.load() - - @pytest.mark.parametrize( - "filepath,instance_type", - [ - ("s3://bucket/file.yaml", S3FileSystem), - ("file:///tmp/test.yaml", LocalFileSystem), - ("/tmp/test.yaml", LocalFileSystem), - ("gcs://bucket/file.yaml", GCSFileSystem), - ("https://example.com/file.yaml", HTTPFileSystem), - ], - ) - def test_protocol_usage(self, filepath, instance_type): - data_set = YAMLDataSet(filepath=filepath) - assert isinstance(data_set._fs, instance_type) - - path = filepath.split(PROTOCOL_DELIMITER, 1)[-1] - - assert str(data_set._filepath) == path - assert isinstance(data_set._filepath, PurePosixPath) - - def test_catalog_release(self, mocker): - fs_mock = mocker.patch("fsspec.filesystem").return_value - filepath = "test.yaml" - data_set = YAMLDataSet(filepath=filepath) - data_set.release() - fs_mock.invalidate_cache.assert_called_once_with(filepath) - - def test_dataframe_support(self, yaml_data_set): - data = pd.DataFrame({"col1": [1, 2], "col2": [4, 5]}) - yaml_data_set.save(data.to_dict()) - reloaded = yaml_data_set.load() - assert isinstance(reloaded, dict) - - data_df = pd.DataFrame.from_dict(reloaded) - assert_frame_equal(data, data_df) - - -class TestYAMLDataSetVersioned: - def test_version_str_repr(self, load_version, save_version): - """Test that version is in string representation of the class instance - when applicable.""" - filepath = "test.yaml" - ds = YAMLDataSet(filepath=filepath) - ds_versioned = YAMLDataSet( - filepath=filepath, version=Version(load_version, save_version) - ) - assert filepath in str(ds) - assert "version" not in str(ds) - - assert filepath in str(ds_versioned) - ver_str = f"version=Version(load={load_version}, save='{save_version}')" - assert ver_str in str(ds_versioned) - assert "YAMLDataSet" in str(ds_versioned) - assert "YAMLDataSet" in str(ds) - assert "protocol" in str(ds_versioned) - assert "protocol" in str(ds) - # Default save_args - assert "save_args={'default_flow_style': False}" in str(ds) - assert "save_args={'default_flow_style': 
False}" in str(ds_versioned) - - def test_save_and_load(self, versioned_yaml_data_set, dummy_data): - """Test that saved and reloaded data matches the original one for - the versioned data set.""" - versioned_yaml_data_set.save(dummy_data) - reloaded = versioned_yaml_data_set.load() - assert dummy_data == reloaded - - def test_no_versions(self, versioned_yaml_data_set): - """Check the error if no versions are available for load.""" - pattern = r"Did not find any versions for YAMLDataSet\(.+\)" - with pytest.raises(DatasetError, match=pattern): - versioned_yaml_data_set.load() - - def test_exists(self, versioned_yaml_data_set, dummy_data): - """Test `exists` method invocation for versioned data set.""" - assert not versioned_yaml_data_set.exists() - versioned_yaml_data_set.save(dummy_data) - assert versioned_yaml_data_set.exists() - - def test_prevent_overwrite(self, versioned_yaml_data_set, dummy_data): - """Check the error when attempting to override the data set if the - corresponding yaml file for a given save version already exists.""" - versioned_yaml_data_set.save(dummy_data) - pattern = ( - r"Save path \'.+\' for YAMLDataSet\(.+\) must " - r"not exist if versioning is enabled\." - ) - with pytest.raises(DatasetError, match=pattern): - versioned_yaml_data_set.save(dummy_data) - - @pytest.mark.parametrize( - "load_version", ["2019-01-01T23.59.59.999Z"], indirect=True - ) - @pytest.mark.parametrize( - "save_version", ["2019-01-02T00.00.00.000Z"], indirect=True - ) - def test_save_version_warning( - self, versioned_yaml_data_set, load_version, save_version, dummy_data - ): - """Check the warning when saving to the path that differs from - the subsequent load path.""" - pattern = ( - rf"Save version '{save_version}' did not match load version " - rf"'{load_version}' for YAMLDataSet\(.+\)" - ) - with pytest.warns(UserWarning, match=pattern): - versioned_yaml_data_set.save(dummy_data) - - def test_http_filesystem_no_versioning(self): - pattern = "Versioning is not supported for HTTP protocols." 
- - with pytest.raises(DatasetError, match=pattern): - YAMLDataSet( - filepath="https://example.com/file.yaml", version=Version(None, None) - ) - - def test_versioning_existing_dataset( - self, yaml_data_set, versioned_yaml_data_set, dummy_data - ): - """Check the error when attempting to save a versioned dataset on top of an - already existing (non-versioned) dataset.""" - yaml_data_set.save(dummy_data) - assert yaml_data_set.exists() - assert yaml_data_set._filepath == versioned_yaml_data_set._filepath - pattern = ( - f"(?=.*file with the same name already exists in the directory)" - f"(?=.*{versioned_yaml_data_set._filepath.parent.as_posix()})" - ) - with pytest.raises(DatasetError, match=pattern): - versioned_yaml_data_set.save(dummy_data) - - # Remove non-versioned dataset and try again - Path(yaml_data_set._filepath.as_posix()).unlink() - versioned_yaml_data_set.save(dummy_data) - assert versioned_yaml_data_set.exists() diff --git a/tests/extras/logging/__init__.py b/tests/extras/logging/__init__.py deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/tests/extras/logging/test_color_logger.py b/tests/extras/logging/test_color_logger.py deleted file mode 100644 index 36c5cc30dc..0000000000 --- a/tests/extras/logging/test_color_logger.py +++ /dev/null @@ -1,16 +0,0 @@ -import logging - -from kedro.extras.logging import ColorHandler - - -def test_color_logger(caplog): - log = logging.getLogger(__name__) - for handler in log.handlers: - log.removeHandler(handler) # pragma: no cover - - log.addHandler(ColorHandler()) - log.info("Test") - - for record in caplog.records: - assert record.levelname == "INFO" - assert "Test" in record.msg diff --git a/tests/framework/cli/micropkg/test_micropkg_package.py b/tests/framework/cli/micropkg/test_micropkg_package.py index 7c0674c6e1..c27e6e8105 100644 --- a/tests/framework/cli/micropkg/test_micropkg_package.py +++ b/tests/framework/cli/micropkg/test_micropkg_package.py @@ -492,10 +492,9 @@ def test_micropkg_package_nested_module( "chdir_to_dummy_project", "cleanup_dist", "cleanup_pyproject_toml" ) class TestMicropkgPackageFromManifest: - def test_micropkg_package_all( # pylint: disable=too-many-locals + def test_micropkg_package_all( self, fake_repo_path, fake_project_cli, fake_metadata, tmp_path, mocker ): - # pylint: disable=import-outside-toplevel from kedro.framework.cli import micropkg spy = mocker.spy(micropkg, "_package_micropkg") @@ -535,7 +534,6 @@ def test_micropkg_package_all( # pylint: disable=too-many-locals def test_micropkg_package_all_empty_toml( self, fake_repo_path, fake_project_cli, fake_metadata, mocker ): - # pylint: disable=import-outside-toplevel from kedro.framework.cli import micropkg spy = mocker.spy(micropkg, "_package_micropkg") diff --git a/tests/framework/cli/micropkg/test_micropkg_pull.py b/tests/framework/cli/micropkg/test_micropkg_pull.py index 3e40843449..13754d9503 100644 --- a/tests/framework/cli/micropkg/test_micropkg_pull.py +++ b/tests/framework/cli/micropkg/test_micropkg_pull.py @@ -74,7 +74,6 @@ def test_pull_local_sdist( fake_metadata, ): """Test for pulling a valid sdist file locally.""" - # pylint: disable=too-many-locals call_pipeline_create(fake_project_cli, fake_metadata) call_micropkg_package(fake_project_cli, fake_metadata) call_pipeline_delete(fake_project_cli, fake_metadata) @@ -147,7 +146,6 @@ def test_pull_local_sdist_compare( into another location and check that unpacked files are identical to the ones in the original modular pipeline. 
""" - # pylint: disable=too-many-locals pipeline_name = "another_pipeline" call_pipeline_create(fake_project_cli, fake_metadata) call_micropkg_package(fake_project_cli, fake_metadata, alias=pipeline_name) @@ -291,7 +289,7 @@ def test_micropkg_pull_nested_destination( expected_test_files = {"__init__.py", "test_pipeline.py"} assert actual_test_files == expected_test_files - def test_micropkg_alias_refactors_imports( # pylint: disable=too-many-locals + def test_micropkg_alias_refactors_imports( self, fake_project_cli, fake_package_path, fake_repo_path, fake_metadata ): call_pipeline_create(fake_project_cli, fake_metadata) @@ -432,7 +430,6 @@ def test_pull_tests_missing( """Test for pulling a valid sdist file locally, but `tests` directory is missing from the sdist file. """ - # pylint: disable=too-many-locals call_pipeline_create(fake_project_cli, fake_metadata) test_path = fake_repo_path / "src" / "tests" / "pipelines" / PIPELINE_NAME shutil.rmtree(test_path) @@ -495,7 +492,6 @@ def test_pull_config_missing( Test for pulling a valid sdist file locally, but `config` directory is missing from the sdist file. """ - # pylint: disable=too-many-locals call_pipeline_create(fake_project_cli, fake_metadata) source_params_config = ( fake_repo_path @@ -560,7 +556,6 @@ def test_pull_from_pypi( """ Test for pulling a valid sdist file from pypi. """ - # pylint: disable=too-many-locals call_pipeline_create(fake_project_cli, fake_metadata) # We mock the `pip download` call, and manually create a package sdist file # to simulate the pypi scenario instead @@ -592,7 +587,7 @@ def test_pull_from_pypi( # Mock needed to avoid an error when build.util.project_wheel_metadata # calls tempfile.TemporaryDirectory, which is mocked class _FakeWheelMetadata: - def get_all(self, name, failobj=None): # pylint: disable=unused-argument + def get_all(self, name, failobj=None): return [] mocker.patch( @@ -854,10 +849,9 @@ def test_path_traversal( "chdir_to_dummy_project", "cleanup_dist", "cleanup_pyproject_toml" ) class TestMicropkgPullFromManifest: - def test_micropkg_pull_all( # pylint: disable=too-many-locals + def test_micropkg_pull_all( self, fake_repo_path, fake_project_cli, fake_metadata, mocker ): - # pylint: disable=import-outside-toplevel, line-too-long from kedro.framework.cli import micropkg spy = mocker.spy(micropkg, "_pull_package") @@ -897,7 +891,6 @@ def test_micropkg_pull_all( # pylint: disable=too-many-locals def test_micropkg_pull_all_empty_toml( self, fake_repo_path, fake_project_cli, fake_metadata, mocker ): - # pylint: disable=import-outside-toplevel from kedro.framework.cli import micropkg spy = mocker.spy(micropkg, "_pull_package") diff --git a/tests/framework/cli/pipeline/test_pipeline.py b/tests/framework/cli/pipeline/test_pipeline.py index 2426a352af..0feb28256d 100644 --- a/tests/framework/cli/pipeline/test_pipeline.py +++ b/tests/framework/cli/pipeline/test_pipeline.py @@ -4,9 +4,9 @@ import pytest import yaml from click.testing import CliRunner +from kedro_datasets.pandas import CSVDataSet from pandas import DataFrame -from kedro.extras.datasets.pandas import CSVDataSet from kedro.framework.cli.pipeline import _sync_dirs from kedro.framework.project import settings from kedro.framework.session import KedroSession @@ -46,7 +46,7 @@ def make_pipelines(request, fake_repo_path, fake_package_path, mocker): @pytest.mark.usefixtures("chdir_to_dummy_project") class TestPipelineCreateCommand: @pytest.mark.parametrize("env", [None, "local"]) - def test_create_pipeline( # pylint: disable=too-many-locals + 
def test_create_pipeline( self, fake_repo_path, fake_project_cli, fake_metadata, env, fake_package_path ): """Test creation of a pipeline""" @@ -80,7 +80,7 @@ def test_create_pipeline( # pylint: disable=too-many-locals assert actual_files == expected_files @pytest.mark.parametrize("env", [None, "local"]) - def test_create_pipeline_template( # pylint: disable=too-many-locals + def test_create_pipeline_template( self, fake_repo_path, fake_project_cli, @@ -112,7 +112,7 @@ def test_create_pipeline_template( # pylint: disable=too-many-locals assert result.exit_code == 0 @pytest.mark.parametrize("env", [None, "local"]) - def test_create_pipeline_template_command_line_override( # pylint: disable=too-many-locals + def test_create_pipeline_template_command_line_override( self, fake_repo_path, fake_project_cli, @@ -171,7 +171,7 @@ def test_create_pipeline_skip_config( test_dir = fake_repo_path / "src" / "tests" / "pipelines" / PIPELINE_NAME assert test_dir.is_dir() - def test_catalog_and_params( # pylint: disable=too-many-locals + def test_catalog_and_params( self, fake_repo_path, fake_project_cli, fake_metadata, fake_package_path ): """Test that catalog and parameter configs generated in pipeline diff --git a/tests/framework/cli/test_catalog.py b/tests/framework/cli/test_catalog.py index a0dba38635..5ad36b2391 100644 --- a/tests/framework/cli/test_catalog.py +++ b/tests/framework/cli/test_catalog.py @@ -1,8 +1,8 @@ import pytest import yaml from click.testing import CliRunner +from kedro_datasets.pandas import CSVDataSet -from kedro.extras.datasets.pandas import CSVDataSet from kedro.io import DataCatalog, MemoryDataset from kedro.pipeline import node from kedro.pipeline.modular_pipeline import pipeline as modular_pipeline @@ -400,7 +400,7 @@ def test_bad_env(self, fake_project_cli, fake_metadata): result = CliRunner().invoke(fake_project_cli, cmd, obj=fake_metadata) assert result.exit_code - assert "Unable to instantiate Kedro session" in result.output + assert "Unable to instantiate Kedro Catalog" in result.output @pytest.mark.usefixtures( diff --git a/tests/framework/cli/test_cli.py b/tests/framework/cli/test_cli.py index 8c33f4e2ae..288272b5dc 100644 --- a/tests/framework/cli/test_cli.py +++ b/tests/framework/cli/test_cli.py @@ -1,9 +1,7 @@ -# pylint: disable=too-many-lines from collections import namedtuple from itertools import cycle from os import rename from pathlib import Path -from unittest.mock import patch import anyconfig import click @@ -43,17 +41,17 @@ def stub_command(): @forward_command(stub_cli, name="forwarded_command") -def forwarded_command(args, **kwargs): # pylint: disable=unused-argument +def forwarded_command(args, **kwargs): print("fred", args) @forward_command(stub_cli, name="forwarded_help", forward_help=True) -def forwarded_help(args, **kwargs): # pylint: disable=unused-argument +def forwarded_help(args, **kwargs): print("fred", args) @forward_command(stub_cli) -def unnamed(args, **kwargs): # pylint: disable=unused-argument +def unnamed(args, **kwargs): print("fred", args) @@ -121,18 +119,6 @@ def test_help(self): assert result.exit_code == 0 assert "-h, --help Show this message and exit." 
in result.output - @patch("webbrowser.open") - def test_docs(self, patched_browser): - """Check that `kedro docs` opens a correct file in the browser.""" - result = CliRunner().invoke(cli, ["docs"]) - - assert result.exit_code == 0 - expected = f"https://kedro.readthedocs.io/en/{version}" - - assert patched_browser.call_count == 1 - args, _ = patched_browser.call_args - assert expected in args[0] - class TestCommandCollection: def test_found(self): diff --git a/tests/framework/cli/test_cli_hooks.py b/tests/framework/cli/test_cli_hooks.py index 41fbdaa705..75ecf9a309 100644 --- a/tests/framework/cli/test_cli_hooks.py +++ b/tests/framework/cli/test_cli_hooks.py @@ -74,7 +74,7 @@ def fake_plugin_distribution(mocker): version="0.1", ) mocker.patch( - "pluggy._manager.importlib_metadata.distributions", + "pluggy._manager.importlib.metadata.distributions", return_value=[fake_distribution], ) return fake_distribution @@ -93,7 +93,7 @@ def test_kedro_cli_should_invoke_cli_hooks_from_plugin( mocker, fake_metadata, fake_plugin_distribution, - entry_points, # pylint: disable=unused-argument + entry_points, ): caplog.set_level(logging.DEBUG, logger="kedro") diff --git a/tests/framework/cli/test_jupyter.py b/tests/framework/cli/test_jupyter.py index 8f363bac3e..20a5dc9ad0 100644 --- a/tests/framework/cli/test_jupyter.py +++ b/tests/framework/cli/test_jupyter.py @@ -1,7 +1,5 @@ -import json import shutil from pathlib import Path -from tempfile import NamedTemporaryFile import pytest from click.testing import CliRunner @@ -11,7 +9,7 @@ get_kernel_spec, ) -from kedro.framework.cli.jupyter import _create_kernel, _export_nodes +from kedro.framework.cli.jupyter import _create_kernel from kedro.framework.cli.utils import KedroCliError @@ -194,304 +192,3 @@ def cleanup_nodes_dir(fake_package_path): nodes_dir = fake_package_path / "nodes" if nodes_dir.exists(): shutil.rmtree(str(nodes_dir)) - - -@pytest.mark.usefixtures("chdir_to_dummy_project", "cleanup_nodes_dir") -class TestConvertNotebookCommand: - @pytest.fixture - def fake_export_nodes(self, mocker): - return mocker.patch("kedro.framework.cli.jupyter._export_nodes") - - @pytest.fixture - def tmp_file_path(self): - with NamedTemporaryFile() as f: - yield Path(f.name) - - # noqa: too-many-arguments - def test_convert_one_file_overwrite( - self, - mocker, - fake_project_cli, - fake_export_nodes, - tmp_file_path, - fake_package_path, - fake_metadata, - ): - """ - Trying to convert one file, the output file already exists, - overwriting it. - """ - mocker.patch.object(Path, "is_file", return_value=True) - mocker.patch("click.confirm", return_value=True) - output_dir = fake_package_path / "nodes" - assert not output_dir.exists() - - result = CliRunner().invoke( - fake_project_cli, - ["jupyter", "convert", str(tmp_file_path)], - obj=fake_metadata, - ) - assert not result.exit_code, result.stdout - - assert (output_dir / "__init__.py").is_file() - fake_export_nodes.assert_called_once_with( - tmp_file_path.resolve(), output_dir / f"{tmp_file_path.stem}.py" - ) - - def test_convert_one_file_do_not_overwrite( - self, mocker, fake_project_cli, fake_export_nodes, tmp_file_path, fake_metadata - ): - """ - Trying to convert one file, the output file already exists, - user refuses to overwrite it. 
- """ - mocker.patch.object(Path, "is_file", return_value=True) - mocker.patch("click.confirm", return_value=False) - - result = CliRunner().invoke( - fake_project_cli, - ["jupyter", "convert", str(tmp_file_path)], - obj=fake_metadata, - ) - assert not result.exit_code, result.stdout - - fake_export_nodes.assert_not_called() - - def test_convert_all_files( - self, - mocker, - fake_project_cli, - fake_export_nodes, - fake_package_path, - fake_metadata, - ): - """Trying to convert all files, the output files already exist.""" - mocker.patch.object(Path, "is_file", return_value=True) - mocker.patch("click.confirm", return_value=True) - mocker.patch( - "kedro.framework.cli.jupyter.iglob", return_value=["/path/1", "/path/2"] - ) - output_dir = fake_package_path / "nodes" - assert not output_dir.exists() - - result = CliRunner().invoke( - fake_project_cli, ["jupyter", "convert", "--all"], obj=fake_metadata - ) - assert not result.exit_code, result.stdout - - assert (output_dir / "__init__.py").is_file() - fake_export_nodes.assert_has_calls( - [ - mocker.call(Path("/path/1"), output_dir / "1.py"), - mocker.call(Path("/path/2"), output_dir / "2.py"), - ] - ) - - def test_convert_without_filepath_and_all_flag( - self, fake_project_cli, fake_metadata - ): - """Neither path nor --all flag is provided.""" - result = CliRunner().invoke( - fake_project_cli, ["jupyter", "convert"], obj=fake_metadata - ) - expected_output = ( - "Please specify a notebook filepath or " - "add '--all' to convert all notebooks.\n" - ) - assert result.exit_code - assert expected_output in result.stdout - - def test_non_unique_notebook_names_error( - self, fake_project_cli, mocker, fake_metadata - ): - """Trying to convert notebooks with the same name.""" - mocker.patch( - "kedro.framework.cli.jupyter.iglob", return_value=["/path1/1", "/path2/1"] - ) - - result = CliRunner().invoke( - fake_project_cli, ["jupyter", "convert", "--all"], obj=fake_metadata - ) - - expected_output = ( - "Error: Found non-unique notebook names! 
Please rename the following: 1\n" - ) - assert result.exit_code - assert expected_output in result.output - - def test_convert_one_file( - self, - fake_project_cli, - fake_export_nodes, - tmp_file_path, - fake_package_path, - fake_metadata, - ): - """Trying to convert one file, the output file doesn't exist.""" - output_dir = fake_package_path / "nodes" - assert not output_dir.exists() - - result = CliRunner().invoke( - fake_project_cli, - ["jupyter", "convert", str(tmp_file_path)], - obj=fake_metadata, - ) - assert not result.exit_code, result.stdout - - assert (output_dir / "__init__.py").is_file() - fake_export_nodes.assert_called_once_with( - tmp_file_path.resolve(), output_dir / f"{tmp_file_path.stem}.py" - ) - - def test_convert_one_file_nodes_directory_exists( - self, - fake_project_cli, - fake_export_nodes, - tmp_file_path, - fake_package_path, - fake_metadata, - ): - """User-created nodes/ directory is used as is.""" - output_dir = fake_package_path / "nodes" - assert not output_dir.exists() - output_dir.mkdir() - - result = CliRunner().invoke( - fake_project_cli, - ["jupyter", "convert", str(tmp_file_path)], - obj=fake_metadata, - ) - assert not result.exit_code, result.stdout - - assert not (output_dir / "__init__.py").is_file() - fake_export_nodes.assert_called_once_with( - tmp_file_path.resolve(), output_dir / f"{tmp_file_path.stem}.py" - ) - - -class TestExportNodes: - @pytest.fixture - def project_path(self, tmp_path): - temp = Path(str(tmp_path)) - return Path(temp / "some/path/to/my_project") - - @pytest.fixture - def nodes_path(self, project_path): - path = project_path / "src/my_project/nodes" - path.mkdir(parents=True) - return path - - def test_export_nodes(self, project_path, nodes_path): - nodes = json.dumps( - { - "cells": [ - { - "cell_type": "code", - "source": "print('hello world')", - "metadata": {"tags": ["node"]}, - }, - { - "cell_type": "code", - "source": "print(10+5)", - "metadata": {"tags": ["node"]}, - }, - {"cell_type": "code", "source": "a = 10", "metadata": {}}, - ] - } - ) - notebook_file = project_path / "notebook.ipynb" - notebook_file.write_text(nodes) - - output_path = nodes_path / f"{notebook_file.stem}.py" - _export_nodes(notebook_file, output_path) - - assert output_path.is_file() - assert output_path.read_text() == "print('hello world')\nprint(10+5)\n" - - def test_export_nodes_different_notebook_paths(self, project_path, nodes_path): - nodes = json.dumps( - { - "cells": [ - { - "cell_type": "code", - "source": "print('hello world')", - "metadata": {"tags": ["node"]}, - } - ] - } - ) - notebook_file1 = project_path / "notebook1.ipynb" - notebook_file1.write_text(nodes) - output_path1 = nodes_path / "notebook1.py" - - notebook_file2 = nodes_path / "notebook2.ipynb" - notebook_file2.write_text(nodes) - output_path2 = nodes_path / "notebook2.py" - - _export_nodes(notebook_file1, output_path1) - _export_nodes(notebook_file2, output_path2) - - assert output_path1.read_text() == "print('hello world')\n" - assert output_path2.read_text() == "print('hello world')\n" - - def test_export_nodes_nothing_to_write(self, project_path, nodes_path): - nodes = json.dumps( - { - "cells": [ - { - "cell_type": "code", - "source": "print('hello world')", - "metadata": {}, - }, - { - "cell_type": "text", - "source": "hello world", - "metadata": {"tags": ["node"]}, - }, - ] - } - ) - notebook_file = project_path / "notebook.iypnb" - notebook_file.write_text(nodes) - - with pytest.warns(UserWarning, match="Skipping notebook"): - output_path = nodes_path / 
f"{notebook_file.stem}.py" - _export_nodes(notebook_file, output_path) - - output_path = nodes_path / "notebook.py" - assert not output_path.exists() - - def test_export_nodes_overwrite(self, project_path, nodes_path): - existing_nodes = nodes_path / "notebook.py" - existing_nodes.touch() - existing_nodes.write_text("original") - - nodes = json.dumps( - { - "cells": [ - { - "cell_type": "code", - "source": "print('hello world')", - "metadata": {"tags": ["node"]}, - } - ] - } - ) - notebook_file = project_path / "notebook.iypnb" - notebook_file.write_text(nodes) - - output_path = nodes_path / f"{notebook_file.stem}.py" - _export_nodes(notebook_file, output_path) - - assert output_path.is_file() - assert output_path.read_text() == "print('hello world')\n" - - def test_export_nodes_json_error(self, nodes_path): - random_file = nodes_path / "notebook.txt" - random_file.touch() - random_file.write_text("original") - output_path = nodes_path / f"{random_file.stem}.py" - - pattern = "Provided filepath is not a Jupyter notebook" - with pytest.raises(KedroCliError, match=pattern): - _export_nodes(random_file, output_path) diff --git a/tests/framework/cli/test_project.py b/tests/framework/cli/test_project.py index d965113ea8..071bc6640d 100644 --- a/tests/framework/cli/test_project.py +++ b/tests/framework/cli/test_project.py @@ -1,234 +1,19 @@ -# pylint: disable=unused-argument import sys -from pathlib import Path import pytest from click.testing import CliRunner -from kedro.framework.cli.project import NO_DEPENDENCY_MESSAGE - @pytest.fixture(autouse=True) def call_mock(mocker): return mocker.patch("kedro.framework.cli.project.call") -@pytest.fixture(autouse=True) -def python_call_mock(mocker): - return mocker.patch("kedro.framework.cli.project.python_call") - - @pytest.fixture def fake_copyfile(mocker): return mocker.patch("shutil.copyfile") -@pytest.mark.usefixtures("chdir_to_dummy_project") -class TestActivateNbstripoutCommand: - @staticmethod - @pytest.fixture() - def fake_nbstripout(): - """ - ``nbstripout`` tries to access ``sys.stdin.buffer.readable`` - on import, but it's patches by pytest. - Let's replace it by the fake! - """ - sys.modules["nbstripout"] = "fake" - yield - del sys.modules["nbstripout"] - - @staticmethod - @pytest.fixture - def fake_git_repo(mocker): - return mocker.patch("subprocess.run", return_value=mocker.Mock(returncode=0)) - - @staticmethod - @pytest.fixture - def without_git_repo(mocker): - return mocker.patch("subprocess.run", return_value=mocker.Mock(returncode=1)) - - def test_install_successfully( - self, fake_project_cli, call_mock, fake_nbstripout, fake_git_repo, fake_metadata - ): - result = CliRunner().invoke( - fake_project_cli, ["activate-nbstripout"], obj=fake_metadata - ) - assert not result.exit_code - - call_mock.assert_called_once_with(["nbstripout", "--install"]) - - fake_git_repo.assert_called_once_with( - ["git", "rev-parse", "--git-dir"], capture_output=True - ) - - def test_nbstripout_not_installed( - self, fake_project_cli, fake_git_repo, mocker, fake_metadata - ): - """ - Run activate-nbstripout target without nbstripout installed - There should be a clear message about it. 
- """ - mocker.patch.dict("sys.modules", {"nbstripout": None}) - - result = CliRunner().invoke( - fake_project_cli, ["activate-nbstripout"], obj=fake_metadata - ) - assert result.exit_code - assert "nbstripout is not installed" in result.stdout - - def test_no_git_repo( - self, fake_project_cli, fake_nbstripout, without_git_repo, fake_metadata - ): - """ - Run activate-nbstripout target with no git repo available. - There should be a clear message about it. - """ - result = CliRunner().invoke( - fake_project_cli, ["activate-nbstripout"], obj=fake_metadata - ) - - assert result.exit_code - assert "Not a git repository" in result.stdout - - def test_no_git_executable( - self, fake_project_cli, fake_nbstripout, mocker, fake_metadata - ): - mocker.patch("subprocess.run", side_effect=FileNotFoundError) - result = CliRunner().invoke( - fake_project_cli, ["activate-nbstripout"], obj=fake_metadata - ) - - assert result.exit_code - assert "Git executable not found. Install Git first." in result.stdout - - -@pytest.mark.usefixtures("chdir_to_dummy_project") -class TestTestCommand: - def test_happy_path(self, fake_project_cli, python_call_mock): - result = CliRunner().invoke(fake_project_cli, ["test", "--random-arg", "value"]) - assert not result.exit_code - python_call_mock.assert_called_once_with("pytest", ("--random-arg", "value")) - - def test_pytest_not_installed( - self, fake_project_cli, python_call_mock, mocker, fake_repo_path, fake_metadata - ): - mocker.patch.dict("sys.modules", {"pytest": None}) - - result = CliRunner().invoke( - fake_project_cli, ["test", "--random-arg", "value"], obj=fake_metadata - ) - expected_message = NO_DEPENDENCY_MESSAGE.format( - module="pytest", src=str(fake_repo_path / "src") - ) - - assert result.exit_code - assert expected_message in result.stdout - python_call_mock.assert_not_called() - - -@pytest.mark.usefixtures("chdir_to_dummy_project") -class TestLintCommand: - @pytest.mark.parametrize("files", [(), ("src",)]) - def test_lint( - self, - fake_project_cli, - python_call_mock, - files, - mocker, - fake_repo_path, - fake_metadata, - ): - mocker.patch("kedro.framework.cli.project._check_module_importable") - result = CliRunner().invoke( - fake_project_cli, ["lint", *files], obj=fake_metadata - ) - assert not result.exit_code, result.stdout - - expected_files = files or ( - str(fake_repo_path / "src/tests"), - str(fake_repo_path / "src/dummy_package"), - ) - expected_calls = [ - mocker.call("black", expected_files), - mocker.call("flake8", expected_files), - mocker.call("isort", expected_files), - ] - - assert python_call_mock.call_args_list == expected_calls - - @pytest.mark.parametrize( - "check_flag,files", - [ - ("-c", ()), - ("--check-only", ()), - ("-c", ("src",)), - ("--check-only", ("src",)), - ], - ) - def test_lint_check_only( - self, - fake_project_cli, - python_call_mock, - check_flag, - mocker, - files, - fake_repo_path, - fake_metadata, - ): - mocker.patch("kedro.framework.cli.project._check_module_importable") - result = CliRunner().invoke( - fake_project_cli, ["lint", check_flag, *files], obj=fake_metadata - ) - assert not result.exit_code, result.stdout - - expected_files = files or ( - str(fake_repo_path / "src/tests"), - str(fake_repo_path / "src/dummy_package"), - ) - expected_calls = [ - mocker.call("black", ("--check",) + expected_files), - mocker.call("flake8", expected_files), - mocker.call("isort", ("--check",) + expected_files), - ] - - assert python_call_mock.call_args_list == expected_calls - - @pytest.mark.parametrize( - 
"module_name,side_effects", - [("flake8", [ImportError, None, None]), ("isort", [None, ImportError, None])], - ) - def test_import_not_installed( - self, - fake_project_cli, - python_call_mock, - module_name, - side_effects, - mocker, - fake_repo_path, - fake_metadata, - ): - # pretending we have the other linting dependencies, but not the - mocker.patch( - "kedro.framework.cli.utils.import_module", side_effect=side_effects - ) - - result = CliRunner().invoke(fake_project_cli, ["lint"], obj=fake_metadata) - expected_message = NO_DEPENDENCY_MESSAGE.format( - module=module_name, src=str(fake_repo_path / "src") - ) - - assert result.exit_code, result.stdout - assert expected_message in result.stdout - python_call_mock.assert_not_called() - - def test_pythonpath_env_var( - self, fake_project_cli, mocker, fake_repo_path, fake_metadata - ): - mocked_environ = mocker.patch("os.environ", {}) - CliRunner().invoke(fake_project_cli, ["lint"], obj=fake_metadata) - assert mocked_environ == {"PYTHONPATH": str(fake_repo_path / "src")} - - @pytest.mark.usefixtures("chdir_to_dummy_project") class TestIpythonCommand: def test_happy_path( @@ -313,161 +98,3 @@ def test_happy_path( ), ] ) - - -@pytest.mark.usefixtures("chdir_to_dummy_project") -class TestBuildDocsCommand: - def test_happy_path( - self, - call_mock, - python_call_mock, - fake_project_cli, - mocker, - fake_repo_path, - fake_metadata, - ): - fake_rmtree = mocker.patch("shutil.rmtree") - - result = CliRunner().invoke(fake_project_cli, ["build-docs"], obj=fake_metadata) - assert not result.exit_code, result.stdout - call_mock.assert_has_calls( - [ - mocker.call( - [ - "sphinx-apidoc", - "--module-first", - "-o", - "docs/source", - str(fake_repo_path / "src/dummy_package"), - ] - ), - mocker.call( - ["sphinx-build", "-M", "html", "docs/source", "docs/build", "-a"] - ), - ] - ) - python_call_mock.assert_has_calls( - [ - mocker.call("pip", ["install", str(fake_repo_path / "src/[docs]")]), - mocker.call( - "pip", - ["install", "-r", str(fake_repo_path / "src/requirements.txt")], - ), - mocker.call("ipykernel", ["install", "--user", "--name=dummy_package"]), - ] - ) - fake_rmtree.assert_called_once_with("docs/build", ignore_errors=True) - - @pytest.mark.parametrize("open_flag", ["-o", "--open"]) - def test_open_docs(self, open_flag, fake_project_cli, mocker, fake_metadata): - mocker.patch("shutil.rmtree") - patched_browser = mocker.patch("webbrowser.open") - result = CliRunner().invoke( - fake_project_cli, ["build-docs", open_flag], obj=fake_metadata - ) - assert not result.exit_code, result.stdout - expected_path = (Path.cwd() / "docs" / "build" / "html" / "index.html").as_uri() - patched_browser.assert_called_once_with(expected_path) - - -@pytest.mark.usefixtures("chdir_to_dummy_project", "fake_copyfile") -class TestBuildReqsCommand: - def test_compile_from_requirements_file( - self, - python_call_mock, - fake_project_cli, - mocker, - fake_repo_path, - fake_copyfile, - fake_metadata, - ): - # File exists: - mocker.patch.object(Path, "is_file", return_value=True) - - result = CliRunner().invoke(fake_project_cli, ["build-reqs"], obj=fake_metadata) - assert not result.exit_code, result.stdout - assert "Requirements built!" 
in result.stdout - - python_call_mock.assert_called_once_with( - "piptools", - [ - "compile", - str(fake_repo_path / "src" / "requirements.txt"), - "--output-file", - str(fake_repo_path / "src" / "requirements.lock"), - ], - ) - - def test_compile_from_input_and_to_output_file( - self, - python_call_mock, - fake_project_cli, - fake_repo_path, - fake_copyfile, - fake_metadata, - ): - # File exists: - input_file = fake_repo_path / "src" / "dev-requirements.txt" - with open(input_file, "a", encoding="utf-8") as file: - file.write("") - output_file = fake_repo_path / "src" / "dev-requirements.lock" - - result = CliRunner().invoke( - fake_project_cli, - [ - "build-reqs", - "--input-file", - str(input_file), - "--output-file", - str(output_file), - ], - obj=fake_metadata, - ) - assert not result.exit_code, result.stdout - assert "Requirements built!" in result.stdout - python_call_mock.assert_called_once_with( - "piptools", - ["compile", str(input_file), "--output-file", str(output_file)], - ) - - @pytest.mark.parametrize( - "extra_args", [["--generate-hashes"], ["-foo", "--bar", "baz"]] - ) - def test_extra_args( - self, - python_call_mock, - fake_project_cli, - fake_repo_path, - extra_args, - fake_metadata, - ): - requirements_txt = fake_repo_path / "src" / "requirements.txt" - - result = CliRunner().invoke( - fake_project_cli, ["build-reqs"] + extra_args, obj=fake_metadata - ) - - assert not result.exit_code, result.stdout - assert "Requirements built!" in result.stdout - - call_args = ( - ["compile"] - + extra_args - + [str(requirements_txt)] - + ["--output-file", str(fake_repo_path / "src" / "requirements.lock")] - ) - python_call_mock.assert_called_once_with("piptools", call_args) - - @pytest.mark.parametrize("os_name", ["posix", "nt"]) - def test_missing_requirements_txt( - self, fake_project_cli, mocker, fake_metadata, os_name, fake_repo_path - ): - """Test error when input file requirements.txt doesn't exists.""" - requirements_txt = fake_repo_path / "src" / "requirements.txt" - - mocker.patch("kedro.framework.cli.project.os").name = os_name - mocker.patch.object(Path, "is_file", return_value=False) - result = CliRunner().invoke(fake_project_cli, ["build-reqs"], obj=fake_metadata) - assert result.exit_code # Error expected - assert isinstance(result.exception, FileNotFoundError) - assert f"File '{requirements_txt}' not found" in str(result.exception) diff --git a/tests/framework/cli/test_starters.py b/tests/framework/cli/test_starters.py index 644e67d592..b8119c3268 100644 --- a/tests/framework/cli/test_starters.py +++ b/tests/framework/cli/test_starters.py @@ -17,7 +17,7 @@ KedroStarterSpec, ) -FILES_IN_TEMPLATE = 30 +FILES_IN_TEMPLATE = 28 @pytest.fixture diff --git a/tests/framework/conftest.py b/tests/framework/conftest.py new file mode 100644 index 0000000000..3923d9f559 --- /dev/null +++ b/tests/framework/conftest.py @@ -0,0 +1,26 @@ +import pytest + +from kedro.framework.project import configure_logging + + +@pytest.fixture +def default_logging_config(): + logging_config = { + "version": 1, + "disable_existing_loggers": False, + "handlers": { + "rich": {"class": "kedro.logging.RichHandler", "rich_tracebacks": True} + }, + "loggers": {"kedro": {"level": "INFO"}}, + "root": {"handlers": ["rich"]}, + } + return logging_config + + +@pytest.fixture(autouse=True) +def reset_logging(request, default_logging_config): + yield + if "nologreset" in request.keywords: + return + + configure_logging(default_logging_config) diff --git a/tests/framework/context/test_context.py 
b/tests/framework/context/test_context.py index bc4dac8ad2..ef79ac9f54 100644 --- a/tests/framework/context/test_context.py +++ b/tests/framework/context/test_context.py @@ -12,7 +12,8 @@ import pytest import toml import yaml -from pandas.util.testing import assert_frame_equal +from attrs.exceptions import FrozenInstanceError +from pandas.testing import assert_frame_equal from kedro import __version__ as kedro_version from kedro.config import ConfigLoader, MissingConfigException @@ -33,7 +34,7 @@ MOCK_PACKAGE_NAME = "mock_package_name" -class BadCatalog: # pylint: disable=too-few-public-methods +class BadCatalog: """ Catalog class that doesn't subclass `DataCatalog`, for testing only. """ @@ -187,9 +188,7 @@ def extra_params(request): @pytest.fixture -def dummy_context( - tmp_path, prepare_project_dir, env, extra_params -): # pylint: disable=unused-argument +def dummy_context(tmp_path, prepare_project_dir, env, extra_params): configure_project(MOCK_PACKAGE_NAME) config_loader = ConfigLoader(str(tmp_path / "conf"), env=env) context = KedroContext( @@ -210,6 +209,10 @@ def test_attributes(self, tmp_path, dummy_context): assert isinstance(dummy_context.project_path, Path) assert dummy_context.project_path == tmp_path.resolve() + def test_immutable_instance(self, dummy_context): + with pytest.raises(FrozenInstanceError): + dummy_context.catalog = 1 + def test_get_catalog_always_using_absolute_path(self, dummy_context): config_loader = dummy_context.config_loader conf_catalog = config_loader.get("catalog*") @@ -223,8 +226,7 @@ def test_get_catalog_always_using_absolute_path(self, dummy_context): ds_path = catalog._data_sets["horses"]._filepath assert PurePath(ds_path.as_posix()).is_absolute() assert ( - ds_path.as_posix() - == (dummy_context._project_path / "horses.csv").as_posix() + ds_path.as_posix() == (dummy_context.project_path / "horses.csv").as_posix() ) def test_get_catalog_validates_transcoded_datasets(self, dummy_context, mocker): diff --git a/tests/framework/project/test_logging.py b/tests/framework/project/test_logging.py index 52e7d5b4c1..3ce6020d65 100644 --- a/tests/framework/project/test_logging.py +++ b/tests/framework/project/test_logging.py @@ -1,4 +1,3 @@ -# pylint: disable=import-outside-toplevel import logging import sys from pathlib import Path @@ -6,35 +5,36 @@ import pytest import yaml -from kedro.framework.project import LOGGING, configure_logging +from kedro.framework.project import LOGGING, configure_logging, configure_project @pytest.fixture -def default_logging_config(): +def default_logging_config_with_project(): logging_config = { "version": 1, "disable_existing_loggers": False, "handlers": { "rich": {"class": "kedro.logging.RichHandler", "rich_tracebacks": True} }, - "loggers": {"kedro": {"level": "INFO"}}, + "loggers": {"kedro": {"level": "INFO"}, "test_project": {"level": "INFO"}}, "root": {"handlers": ["rich"]}, } return logging_config -@pytest.fixture(autouse=True) -def reset_logging(default_logging_config): - yield - configure_logging(default_logging_config) - - def test_default_logging_config(default_logging_config): assert LOGGING.data == default_logging_config assert "rich" in {handler.name for handler in logging.getLogger().handlers} assert logging.getLogger("kedro").level == logging.INFO +def test_project_logging_in_default_logging_config(default_logging_config_with_project): + configure_project("test_project") + assert LOGGING.data == default_logging_config_with_project + assert logging.getLogger("kedro").level == logging.INFO + assert 
logging.getLogger("test_project").level == logging.INFO + + def test_environment_variable_logging_config(monkeypatch, tmp_path): config_path = Path(tmp_path) / "logging.yml" monkeypatch.setenv("KEDRO_LOGGING_CONFIG", config_path.absolute()) diff --git a/tests/framework/project/test_pipeline_registry.py b/tests/framework/project/test_pipeline_registry.py index b210de009d..a03dc9aad5 100644 --- a/tests/framework/project/test_pipeline_registry.py +++ b/tests/framework/project/test_pipeline_registry.py @@ -26,12 +26,11 @@ def register_pipelines(): def test_pipelines_without_configure_project_is_empty( - mock_package_name_with_pipelines_file, # pylint: disable=unused-argument + mock_package_name_with_pipelines_file, ): # Reimport `pipelines` from `kedro.framework.project` to ensure that # it was not set by a pior call to the `configure_project` function. del sys.modules["kedro.framework.project"] - # pylint: disable=reimported, import-outside-toplevel from kedro.framework.project import pipelines assert pipelines == {} diff --git a/tests/framework/project/test_settings.py b/tests/framework/project/test_settings.py index 65774e0e37..779288fd2f 100644 --- a/tests/framework/project/test_settings.py +++ b/tests/framework/project/test_settings.py @@ -19,7 +19,7 @@ class MyDataCatalog(DataCatalog): pass -class ProjectHooks: # pylint: disable=too-few-public-methods +class ProjectHooks: pass diff --git a/tests/framework/session/conftest.py b/tests/framework/session/conftest.py index 1ac0de6301..c38a363666 100644 --- a/tests/framework/session/conftest.py +++ b/tests/framework/session/conftest.py @@ -122,7 +122,7 @@ def mock_pipeline() -> Pipeline: ) -class LogRecorder(logging.Handler): # pylint: disable=abstract-method +class LogRecorder(logging.Handler): """Record logs received from a process-safe log listener""" def __init__(self): @@ -370,9 +370,7 @@ class MockSettings(_ProjectSettings): @pytest.fixture -def mock_session( - mock_settings, mock_package_name, tmp_path -): # pylint: disable=unused-argument +def mock_session(mock_settings, mock_package_name, tmp_path): configure_project(mock_package_name) session = KedroSession.create( mock_package_name, tmp_path, extra_params={"params:key": "value"} diff --git a/tests/framework/session/test_session.py b/tests/framework/session/test_session.py index c3cfb2bf7b..448540ce88 100644 --- a/tests/framework/session/test_session.py +++ b/tests/framework/session/test_session.py @@ -15,6 +15,7 @@ from kedro.framework.cli.utils import _split_params from kedro.framework.context import KedroContext from kedro.framework.project import ( + LOGGING, ValidationError, Validator, _HasSharedParentClassValidator, @@ -30,13 +31,13 @@ _FAKE_PIPELINE_NAME = "fake_pipeline" -class BadStore: # pylint: disable=too-few-public-methods +class BadStore: """ Store class that doesn't subclass `BaseSessionStore`, for testing only. """ -class BadConfigLoader: # pylint: disable=too-few-public-methods +class BadConfigLoader: """ ConfigLoader class that doesn't subclass `AbstractConfigLoader`, for testing only. 
""" @@ -256,7 +257,6 @@ def test_create( ): mock_click_ctx = mocker.patch("click.get_current_context").return_value mocker.patch("sys.argv", ["kedro", "run", "--params=x"]) - mocker.patch("kedro.framework.session.KedroSession._setup_logging") session = KedroSession.create( mock_package_name, fake_project, env=env, extra_params=extra_params ) @@ -328,7 +328,7 @@ def test_load_context_with_envvar( self, fake_project, monkeypatch, mock_package_name, mocker ): mocker.patch("kedro.config.config.ConfigLoader.get") - mocker.patch("kedro.framework.session.KedroSession._setup_logging") + monkeypatch.setenv("KEDRO_ENV", "my_fake_env") session = KedroSession.create(mock_package_name, fake_project) @@ -343,7 +343,6 @@ def test_load_config_loader_with_envvar( self, fake_project, monkeypatch, mock_package_name, mocker ): mocker.patch("kedro.config.config.ConfigLoader.get") - mocker.patch("kedro.framework.session.KedroSession._setup_logging") monkeypatch.setenv("KEDRO_ENV", "my_fake_env") session = KedroSession.create(mock_package_name, fake_project) @@ -401,23 +400,16 @@ def test_broken_config_loader(self, mock_settings_file_bad_config_loader_class): with pytest.raises(ValidationError, match=re.escape(pattern)): assert mock_settings.CONFIG_LOADER_CLASS - def test_no_logging_config(self, fake_project, caplog, mock_package_name, mocker): + def test_logging_is_not_reconfigure( + self, fake_project, caplog, mock_package_name, mocker + ): caplog.set_level(logging.DEBUG, logger="kedro") - mocker.patch("subprocess.check_output") + mock_logging = mocker.patch.object(LOGGING, "configure") session = KedroSession.create(mock_package_name, fake_project) session.close() - expected_log_messages = [ - "No project logging configuration loaded; " - "Kedro's default logging configuration will be used." 
- ] - actual_log_messages = [ - rec.getMessage() - for rec in caplog.records - if rec.name == SESSION_LOGGER_NAME and rec.levelno == logging.DEBUG - ] - assert actual_log_messages == expected_log_messages + mock_logging.assert_not_called() @pytest.mark.usefixtures("mock_settings_context_class") def test_default_store( @@ -539,7 +531,6 @@ def test_git_describe_error( """ caplog.set_level(logging.DEBUG, logger="kedro") - mocker.patch("kedro.framework.session.KedroSession._setup_logging") mocker.patch("subprocess.check_output", side_effect=exception) session = KedroSession.create(mock_package_name, fake_project) assert "git" not in session.store @@ -559,7 +550,6 @@ def test_get_username_error(self, fake_project, mock_package_name, mocker, caplo caplog.set_level(logging.DEBUG, logger="kedro") mocker.patch("subprocess.check_output") - mocker.patch("kedro.framework.session.KedroSession._setup_logging") mocker.patch("getpass.getuser", side_effect=FakeException("getuser error")) session = KedroSession.create(mock_package_name, fake_project) assert "username" not in session.store @@ -655,7 +645,7 @@ def test_run( @pytest.mark.usefixtures("mock_settings_context_class") @pytest.mark.parametrize("fake_pipeline_name", [None, _FAKE_PIPELINE_NAME]) - def test_run_multiple_times( # pylint: disable=too-many-locals + def test_run_multiple_times( self, fake_project, fake_session_id, @@ -740,7 +730,7 @@ def test_run_non_existent_pipeline( @pytest.mark.usefixtures("mock_settings_context_class") @pytest.mark.parametrize("fake_pipeline_name", [None, _FAKE_PIPELINE_NAME]) - def test_run_exception( # pylint: disable=too-many-locals + def test_run_exception( self, fake_project, fake_session_id, @@ -806,7 +796,7 @@ def test_run_exception( # pylint: disable=too-many-locals @pytest.mark.usefixtures("mock_settings_context_class") @pytest.mark.parametrize("fake_pipeline_name", [None, _FAKE_PIPELINE_NAME]) - def test_run_broken_pipeline_multiple_times( # pylint: disable=too-many-locals + def test_run_broken_pipeline_multiple_times( self, fake_project, fake_session_id, @@ -922,40 +912,6 @@ def fake_project_with_logging_file_handler(fake_project): return fake_project -@pytest.mark.usefixtures("mock_settings") -def test_setup_logging_using_absolute_path( - fake_project_with_logging_file_handler, mocker, mock_package_name -): - mocked_logging = mocker.patch("logging.config.dictConfig") - KedroSession.create(mock_package_name, fake_project_with_logging_file_handler) - - mocked_logging.assert_called_once() - call_args = mocked_logging.call_args[0][0] - - expected_log_filepath = ( - fake_project_with_logging_file_handler / "logs" / "info.log" - ).as_posix() - actual_log_filepath = call_args["handlers"]["info_file_handler"]["filename"] - assert actual_log_filepath == expected_log_filepath - - -@pytest.mark.usefixtures("mock_settings_omega_config_loader_class") -def test_setup_logging_using_omega_config_loader_class( - fake_project_with_logging_file_handler, mocker, mock_package_name -): - mocked_logging = mocker.patch("logging.config.dictConfig") - KedroSession.create(mock_package_name, fake_project_with_logging_file_handler) - - mocked_logging.assert_called_once() - call_args = mocked_logging.call_args[0][0] - - expected_log_filepath = ( - fake_project_with_logging_file_handler / "logs" / "info.log" - ).as_posix() - actual_log_filepath = call_args["handlers"]["info_file_handler"]["filename"] - assert actual_log_filepath == expected_log_filepath - - def get_all_values(mapping: Mapping): for value in mapping.values(): yield value 
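# --- Illustrative sketch (not part of the patch) ----------------------------
# The new tests/framework/conftest.py introduced above adds an autouse
# reset_logging fixture that re-applies Kedro's default logging configuration
# after every test unless the test carries a "nologreset" marker (as the
# hook-manager tests further below do). A minimal usage example, assuming the
# module lives under tests/framework/ so that conftest applies, and using the
# fixture/marker names from this patch:
import logging

import pytest


def test_temporarily_raises_kedro_verbosity():
    # No marker: the autouse reset_logging fixture calls configure_logging()
    # with the default config after this test, so the change does not leak.
    logging.getLogger("kedro").setLevel(logging.DEBUG)
    assert logging.getLogger("kedro").level == logging.DEBUG


def test_sees_default_level_again():
    # The default config in the new conftest sets the "kedro" logger to INFO.
    assert logging.getLogger("kedro").level == logging.INFO


@pytest.mark.nologreset
def test_opts_out_of_the_reset():
    # "nologreset" ends up in request.keywords, so the fixture's teardown
    # returns early and any logging tweaks made here persist beyond this test.
    logging.getLogger("kedro").setLevel(logging.WARNING)
    assert logging.getLogger("kedro").level == logging.WARNING
# -----------------------------------------------------------------------------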
diff --git a/tests/framework/session/test_session_extension_hooks.py b/tests/framework/session/test_session_extension_hooks.py index 0e966c7aa3..3f407852b9 100644 --- a/tests/framework/session/test_session_extension_hooks.py +++ b/tests/framework/session/test_session_extension_hooks.py @@ -419,7 +419,7 @@ def test_before_and_after_dataset_saved_hooks_parallel_runner( assert record.data.to_dict() == dummy_dataframe.to_dict() -class MockDatasetReplacement: # pylint: disable=too-few-public-methods +class MockDatasetReplacement: pass diff --git a/tests/framework/session/test_session_hook_manager.py b/tests/framework/session/test_session_hook_manager.py index 32592a0827..7d67b4d05c 100644 --- a/tests/framework/session/test_session_hook_manager.py +++ b/tests/framework/session/test_session_hook_manager.py @@ -36,11 +36,13 @@ class MockSettings(_ProjectSettings): class TestSessionHookManager: """Test the process of registering hooks with the hook manager in a session.""" + @pytest.mark.nologreset def test_assert_register_hooks(self, project_hooks, mock_session): hook_manager = mock_session._hook_manager assert hook_manager.is_registered(project_hooks) @pytest.mark.usefixtures("mock_session") + @pytest.mark.nologreset def test_calling_register_hooks_twice(self, project_hooks, mock_session): """Calling hook registration multiple times should not raise""" hook_manager = mock_session._hook_manager @@ -51,6 +53,7 @@ def test_calling_register_hooks_twice(self, project_hooks, mock_session): assert hook_manager.is_registered(project_hooks) @pytest.mark.parametrize("num_plugins", [0, 1]) + @pytest.mark.nologreset def test_hooks_registered_when_session_created( self, mocker, request, caplog, project_hooks, num_plugins ): @@ -81,6 +84,7 @@ def test_hooks_registered_when_session_created( assert expected_msg in log_messages @pytest.mark.usefixtures("mock_settings_with_disabled_hooks") + @pytest.mark.nologreset def test_disabling_auto_discovered_hooks( self, mocker, diff --git a/tests/io/test_cached_dataset.py b/tests/io/test_cached_dataset.py index 2d8145318a..f665021dd3 100644 --- a/tests/io/test_cached_dataset.py +++ b/tests/io/test_cached_dataset.py @@ -3,15 +3,15 @@ import pytest import yaml +from kedro_datasets.pandas import CSVDataSet -from kedro.extras.datasets.pandas import CSVDataSet from kedro.io import CachedDataset, DataCatalog, DatasetError, MemoryDataset YML_CONFIG = """ test_ds: type: CachedDataset dataset: - type: kedro.extras.datasets.pandas.CSVDataSet + type: kedro_datasets.pandas.CSVDataSet filepath: example.csv """ @@ -20,7 +20,7 @@ type: CachedDataset versioned: true dataset: - type: kedro.extras.datasets.pandas.CSVDataSet + type: kedro_datasets.pandas.CSVDataSet filepath: example.csv """ @@ -28,7 +28,7 @@ test_ds: type: CachedDataset dataset: - type: kedro.extras.datasets.pandas.CSVDataSet + type: kedro_datasets.pandas.CSVDataSet filepath: example.csv versioned: true """ @@ -60,10 +60,10 @@ def test_save_load_caching(self, mocker): cached_ds.save(42) assert cached_ds.load() == 42 - assert wrapped.load.call_count == 0 # pylint: disable=no-member - assert wrapped.save.call_count == 1 # pylint: disable=no-member - assert cached_ds._cache.load.call_count == 1 # pylint: disable=no-member - assert cached_ds._cache.save.call_count == 1 # pylint: disable=no-member + assert wrapped.load.call_count == 0 + assert wrapped.save.call_count == 1 + assert cached_ds._cache.load.call_count == 1 + assert cached_ds._cache.save.call_count == 1 def test_load_empty_cache(self, mocker): wrapped = 
MemoryDataset(-42) @@ -73,8 +73,8 @@ def test_load_empty_cache(self, mocker): mocker.spy(cached_ds._cache, "load") assert cached_ds.load() == -42 - assert wrapped.load.call_count == 1 # pylint: disable=no-member - assert cached_ds._cache.load.call_count == 0 # pylint: disable=no-member + assert wrapped.load.call_count == 1 + assert cached_ds._cache.load.call_count == 0 def test_from_yaml(self, mocker): config = yaml.safe_load(StringIO(YML_CONFIG)) diff --git a/tests/io/test_data_catalog.py b/tests/io/test_data_catalog.py index f4ac13974f..729d76d64a 100644 --- a/tests/io/test_data_catalog.py +++ b/tests/io/test_data_catalog.py @@ -7,9 +7,9 @@ import pandas as pd import pytest -from pandas.util.testing import assert_frame_equal +from kedro_datasets.pandas import CSVDataSet, ParquetDataSet +from pandas.testing import assert_frame_equal -from kedro.extras.datasets.pandas import CSVDataSet, ParquetDataSet from kedro.io import ( AbstractDataset, DataCatalog, @@ -201,7 +201,7 @@ def conflicting_feed_dict(): class BadDataset(AbstractDataset): # pragma: no cover def __init__(self, filepath): self.filepath = filepath - raise Exception("Naughty!") # pylint: disable=broad-exception-raised + raise Exception("Naughty!") def _load(self): return None @@ -405,7 +405,11 @@ def test_confirm(self, mocker, caplog): data_catalog.confirm("mocked") mock_ds.confirm.assert_called_once_with() assert caplog.record_tuples == [ - ("kedro.io.data_catalog", logging.INFO, "Confirming dataset 'mocked'") + ( + "kedro.io.data_catalog", + logging.INFO, + "Confirming dataset 'mocked'", + ) ] @pytest.mark.parametrize( @@ -465,10 +469,10 @@ def test_config_relative_import(self, sane_config): DataCatalog.from_config(**sane_config) def test_config_import_kedro_datasets(self, sane_config, mocker): - """Test kedro.extras.datasets default path to the dataset class""" + """Test kedro_datasets default path to the dataset class""" # Spy _load_obj because kedro_datasets is not installed and we can't import it. 
- import kedro.io.core # pylint: disable=import-outside-toplevel + import kedro.io.core spy = mocker.spy(kedro.io.core, "_load_obj") parse_dataset_definition(sane_config["catalog"]["boats"]) @@ -479,7 +483,7 @@ def test_config_import_kedro_datasets(self, sane_config, mocker): assert call_args[0][0] == f"{prefix}pandas.CSVDataSet" def test_config_import_extras(self, sane_config): - """Test kedro.extras.datasets default path to the dataset class""" + """Test kedro_datasets default path to the dataset class""" sane_config["catalog"]["boats"]["type"] = "pandas.CSVDataSet" assert DataCatalog.from_config(**sane_config) @@ -527,7 +531,7 @@ def test_missing_credentials(self, sane_config): def test_link_credentials(self, sane_config, mocker): """Test credentials being linked to the relevant data set""" - mock_client = mocker.patch("kedro.extras.datasets.pandas.csv_dataset.fsspec") + mock_client = mocker.patch("kedro_datasets.pandas.csv_dataset.fsspec") config = deepcopy(sane_config) del config["catalog"]["boats"] @@ -537,7 +541,7 @@ def test_link_credentials(self, sane_config, mocker): mock_client.filesystem.assert_called_with("s3", **expected_client_kwargs) def test_nested_credentials(self, sane_config_with_nested_creds, mocker): - mock_client = mocker.patch("kedro.extras.datasets.pandas.csv_dataset.fsspec") + mock_client = mocker.patch("kedro_datasets.pandas.csv_dataset.fsspec") config = deepcopy(sane_config_with_nested_creds) del config["catalog"]["boats"] DataCatalog.from_config(**config) @@ -565,11 +569,10 @@ def test_missing_dependency(self, sane_config, mocker): """Test that dependency is missing.""" pattern = "dependency issue" - # pylint: disable=unused-argument,inconsistent-return-statements def dummy_load(obj_path, *args, **kwargs): - if obj_path == "kedro.extras.datasets.pandas.CSVDataSet": + if obj_path == "kedro_datasets.pandas.CSVDataSet": raise AttributeError(pattern) - if obj_path == "kedro.extras.datasets.pandas.__all__": + if obj_path == "kedro_datasets.pandas.__all__": return ["CSVDataSet"] mocker.patch("kedro.io.core.load_obj", side_effect=dummy_load) @@ -657,7 +660,7 @@ def test_from_sane_config_versioned(self, sane_config, dummy_dataframe): # Verify that `VERSION_FORMAT` can help regenerate `current_ts`. actual_timestamp = datetime.strptime( - catalog.datasets.boats.resolve_load_version(), # pylint: disable=no-member + catalog.datasets.boats.resolve_load_version(), VERSION_FORMAT, ) expected_timestamp = current_ts.replace( @@ -702,11 +705,11 @@ def test_compare_tracking_and_other_dataset_versioned( # Verify that saved version on tracking dataset is the same as on the CSV dataset csv_timestamp = datetime.strptime( - catalog.datasets.boats.resolve_save_version(), # pylint: disable=no-member + catalog.datasets.boats.resolve_save_version(), VERSION_FORMAT, ) tracking_timestamp = datetime.strptime( - catalog.datasets.planes.resolve_save_version(), # pylint: disable=no-member + catalog.datasets.planes.resolve_save_version(), VERSION_FORMAT, ) @@ -915,7 +918,7 @@ def test_factory_config_versioned( # Verify that `VERSION_FORMAT` can help regenerate `current_ts`. 
         actual_timestamp = datetime.strptime(
-            catalog.datasets.tesla_cars.resolve_load_version(),  # pylint: disable=no-member
+            catalog.datasets.tesla_cars.resolve_load_version(),
             VERSION_FORMAT,
         )
         expected_timestamp = current_ts.replace(
diff --git a/tests/io/test_incremental_dataset.py b/tests/io/test_incremental_dataset.py
index 76218b6324..db9421e886 100644
--- a/tests/io/test_incremental_dataset.py
+++ b/tests/io/test_incremental_dataset.py
@@ -8,15 +8,15 @@
 import boto3
 import pandas as pd
 import pytest
+from kedro_datasets.pickle import PickleDataSet
+from kedro_datasets.text import TextDataSet
 from moto import mock_s3
-from pandas.util.testing import assert_frame_equal
+from pandas.testing import assert_frame_equal
 
-from kedro.extras.datasets.pickle import PickleDataSet
-from kedro.extras.datasets.text import TextDataSet
 from kedro.io import AbstractDataset, DatasetError, IncrementalDataset
 from kedro.io.data_catalog import CREDENTIALS_KEY
 
-DATASET = "kedro.extras.datasets.pandas.CSVDataSet"
+DATASET = "kedro_datasets.pandas.CSVDataSet"
 
 
 @pytest.fixture
@@ -227,7 +227,7 @@ def test_checkpoint_path(self, local_csvs, partitioned_data_pandas):
         "checkpoint_config,expected_checkpoint_class",
         [
             (None, TextDataSet),
-            ({"type": "kedro.extras.datasets.pickle.PickleDataSet"}, PickleDataSet),
+            ({"type": "kedro_datasets.pickle.PickleDataSet"}, PickleDataSet),
             ({"type": "tests.io.test_incremental_dataset.DummyDataset"}, DummyDataset),
         ],
     )
diff --git a/tests/io/test_memory_dataset.py b/tests/io/test_memory_dataset.py
index 81d1c3fc38..3ff032b71e 100644
--- a/tests/io/test_memory_dataset.py
+++ b/tests/io/test_memory_dataset.py
@@ -1,6 +1,5 @@
 import re
 
-# pylint: disable=unused-argument
 import numpy as np
 import pandas as pd
 import pytest
@@ -218,7 +217,7 @@ def test_infer_mode_deepcopy(data):
 
 
 def test_infer_mode_assign():
-    class DataFrame:  # pylint: disable=too-few-public-methods
+    class DataFrame:
         pass
 
     data = DataFrame()
diff --git a/tests/io/test_partitioned_dataset.py b/tests/io/test_partitioned_dataset.py
index 453ff1781e..067c28e7da 100644
--- a/tests/io/test_partitioned_dataset.py
+++ b/tests/io/test_partitioned_dataset.py
@@ -7,10 +7,10 @@
 import pandas as pd
 import pytest
 import s3fs
+from kedro_datasets.pandas import CSVDataSet, ParquetDataSet
 from moto import mock_s3
-from pandas.util.testing import assert_frame_equal
+from pandas.testing import assert_frame_equal
 
-from kedro.extras.datasets.pandas import CSVDataSet, ParquetDataSet
 from kedro.io import DatasetError, PartitionedDataset
 from kedro.io.data_catalog import CREDENTIALS_KEY
 from kedro.io.partitioned_dataset import KEY_PROPAGATION_WARNING
@@ -39,14 +39,14 @@ def local_csvs(tmp_path, partitioned_data_pandas):
 
 LOCAL_DATASET_DEFINITION = [
     "pandas.CSVDataSet",
-    "kedro.extras.datasets.pandas.CSVDataSet",
+    "kedro_datasets.pandas.CSVDataSet",
     CSVDataSet,
     {"type": "pandas.CSVDataSet", "save_args": {"index": False}},
     {"type": CSVDataSet},
 ]
 
 
-class FakeDataset:  # pylint: disable=too-few-public-methods
+class FakeDataset:
     pass
 
 
@@ -404,7 +404,7 @@ def test_dataset_creds(self, pds_config, expected_ds_creds, global_creds):
 BUCKET_NAME = "fake_bucket_name"
 S3_DATASET_DEFINITION = [
     "pandas.CSVDataSet",
-    "kedro.extras.datasets.pandas.CSVDataSet",
+    "kedro_datasets.pandas.CSVDataSet",
     CSVDataSet,
     {"type": "pandas.CSVDataSet", "save_args": {"index": False}},
     {"type": CSVDataSet},
diff --git a/tests/ipython/test_ipython.py b/tests/ipython/test_ipython.py
index 2ff8b5c416..95a23283f8 100644
--- a/tests/ipython/test_ipython.py
+++ b/tests/ipython/test_ipython.py
@@ -1,4 +1,3 @@
-# pylint: disable=import-outside-toplevel
 from pathlib import Path
 
 import pytest
@@ -18,7 +17,7 @@
 
 @pytest.fixture(autouse=True)
 def cleanup_pipeline():
     yield
-    from kedro.framework.project import pipelines  # pylint: disable=reimported
+    from kedro.framework.project import pipelines
 
     pipelines.configure()
@@ -165,9 +164,6 @@ class TestLoadIPythonExtension:
     def test_load_ipython_extension(self, ipython):
         ipython.magic("load_ext kedro.ipython")
 
-    def test_load_ipython_extension_old_location(self, ipython):
-        ipython.magic("load_ext kedro.ipython")
-
     def test_load_extension_missing_dependency(self, mocker):
         mocker.patch("kedro.ipython.reload_kedro", side_effect=ImportError)
         mocker.patch(
@@ -249,7 +245,6 @@ def test_only_path_specified(self):
         assert result == expected
 
     def test_only_local_namespace_specified(self):
-        # pylint: disable=too-few-public-methods
         class MockKedroContext:
             # A dummy stand-in for KedroContext sufficient for this test
             _project_path = Path("/test").resolve()
@@ -283,7 +278,6 @@ def test_project_path_unresolvable_warning(self, mocker, caplog, ipython):
         assert expected_message in log_messages
 
     def test_project_path_update(self, caplog):
-        # pylint: disable=too-few-public-methods
         class MockKedroContext:
             # A dummy stand-in for KedroContext sufficient for this test
             _project_path = Path("/test").resolve()
diff --git a/tests/pipeline/test_node.py b/tests/pipeline/test_node.py
index 14e4782ea2..3fdfdbaecb 100644
--- a/tests/pipeline/test_node.py
+++ b/tests/pipeline/test_node.py
@@ -187,7 +187,7 @@ def test_node_invalid_less_than(self):
 
         pattern = "'<' not supported between instances of 'Node' and 'str'"
         with pytest.raises(TypeError, match=pattern):
-            n < "hello"  # pylint: disable=pointless-statement
+            n < "hello"
 
     def test_different_input_list_order_not_equal(self):
         first = node(biconcat, ["input1", "input2"], "output1", name="A")
@@ -266,13 +266,13 @@ def duplicate_output_list_node():
         (
             duplicate_output_dict_node,
             r"Failed to create node identity"
-            r"\(\[A\]\) -> \[A,A\] due to "
+            r"\(\[A\]\) -> \[A;A\] due to "
            r"duplicate output\(s\) {\'A\'}.",
         ),
         (
             duplicate_output_list_node,
             r"Failed to create node identity"
-            r"\(\[A\]\) -> \[A,A\] due to "
+            r"\(\[A\]\) -> \[A;A\] due to "
             r"duplicate output\(s\) {\'A\'}.",
         ),
     ],
@@ -300,7 +300,7 @@ def dummy_func_args(**kwargs):
     return dummy_func_args, "A", "B"
 
 
-lambda_identity = lambda input1: input1  # noqa: disable=E731 # pylint: disable=C3001
+lambda_identity = lambda input1: input1  # noqa: disable=E731
 
 
 def lambda_inconsistent_input_size():
diff --git a/tests/pipeline/test_node_run.py b/tests/pipeline/test_node_run.py
index 40289890ef..458e011f6e 100644
--- a/tests/pipeline/test_node_run.py
+++ b/tests/pipeline/test_node_run.py
@@ -1,5 +1,3 @@
-# pylint: disable=unused-argument
-
 import pytest
 
 from kedro.io import LambdaDataset
diff --git a/tests/pipeline/test_pipeline.py b/tests/pipeline/test_pipeline.py
index 075557d148..780595dc69 100644
--- a/tests/pipeline/test_pipeline.py
+++ b/tests/pipeline/test_pipeline.py
@@ -383,7 +383,7 @@ def test_bad_combine_node(self):
         fred = node(identity, "input", "output")
         pipeline = modular_pipeline([fred])
         with pytest.raises(TypeError):
-            pipeline + fred  # pylint: disable=pointless-statement
+            pipeline + fred
 
     def test_bad_combine_int(self):
         """int cannot be combined to pipeline, tests __radd__"""
@@ -405,7 +405,7 @@ def test_conflicting_names(self):
             "appear more than once:\n\nFree nodes:\n - a"
         )
         with pytest.raises(ValueError, match=re.escape(pattern)):
-            pipeline1 + new_pipeline  # pylint: disable=pointless-statement
+            pipeline1 + new_pipeline
 
     def test_conflicting_outputs(self):
         """Node outputs must be unique."""
@@ -416,7 +416,7 @@ def test_conflicting_outputs(self):
             [node(biconcat, ["input", "input2"], ["output", "output2"], name="b")]
         )
         with pytest.raises(OutputNotUniqueError, match=r"\['output'\]"):
-            pipeline1 + new_pipeline  # pylint: disable=pointless-statement
+            pipeline1 + new_pipeline
 
     def test_duplicate_node_confirms(self):
         """Test that non-unique dataset confirms break pipeline concatenation"""
@@ -427,7 +427,7 @@ def test_duplicate_node_confirms(self):
             [node(identity, "input2", "output2", confirms=["other", "output2"])]
         )
         with pytest.raises(ConfirmNotUniqueError, match=r"\['other'\]"):
-            pipeline1 + pipeline2  # pylint: disable=pointless-statement
+            pipeline1 + pipeline2
 
 
 class TestPipelineOperators:
@@ -533,7 +533,7 @@ def test_invalid_remove(self):
         p = modular_pipeline([])
         pattern = r"unsupported operand type\(s\) for -: 'Pipeline' and 'str'"
         with pytest.raises(TypeError, match=pattern):
-            p - "hello"  # pylint: disable=pointless-statement
+            p - "hello"
 
     def test_combine_same_node(self):
         """Multiple (identical) pipelines are possible"""
@@ -567,7 +567,7 @@ def test_invalid_intersection(self):
         p = modular_pipeline([])
         pattern = r"unsupported operand type\(s\) for &: 'Pipeline' and 'str'"
         with pytest.raises(TypeError, match=pattern):
-            p & "hello"  # pylint: disable=pointless-statement
+            p & "hello"
 
     def test_union(self):
         pipeline1 = modular_pipeline(
@@ -588,7 +588,7 @@ def test_invalid_union(self):
         p = modular_pipeline([])
         pattern = r"unsupported operand type\(s\) for |: 'Pipeline' and 'str'"
         with pytest.raises(TypeError, match=pattern):
-            p | "hello"  # pylint: disable=pointless-statement
+            p | "hello"
 
     def test_node_unique_confirms(self):
         """Test that unique dataset confirms don't break pipeline concatenation"""
@@ -649,7 +649,7 @@ def test_full(self, str_node_inputs_list):
             "#### Pipeline execution order ####",
             "Inputs: input1, input2",
             "",
-            "node1: biconcat([input1,input2]) -> [input3]",
+            "node1: biconcat([input1;input2]) -> [input3]",
             "node2: identity([input3]) -> [input4]",
             "",
             "Outputs: input4",
diff --git a/tests/runner/conftest.py b/tests/runner/conftest.py
index 85ef9b9aa5..4c720a7a4a 100644
--- a/tests/runner/conftest.py
+++ b/tests/runner/conftest.py
@@ -15,7 +15,7 @@ def identity(arg):
     return arg
 
 
-def sink(arg):  # pylint: disable=unused-argument
+def sink(arg):
     pass
 
 
@@ -24,7 +24,7 @@ def fan_in(*args):
 
 
 def exception_fn(*args):
-    raise Exception("test exception")  # pylint: disable=broad-exception-raised
+    raise Exception("test exception")
 
 
 def return_none(arg):
@@ -32,7 +32,7 @@
     return arg
 
 
-def return_not_serialisable(arg):  # pylint: disable=unused-argument
+def return_not_serialisable(arg):
     return lambda x: x
 
 
@@ -70,7 +70,6 @@ def persistent_dataset_catalog():
     def _load():
         return 0
 
-    # pylint: disable=unused-argument
     def _save(arg):
         pass
diff --git a/tests/test_utils.py b/tests/test_utils.py
index 4e99f3f726..1ca93067df 100644
--- a/tests/test_utils.py
+++ b/tests/test_utils.py
@@ -5,7 +5,6 @@
 from kedro.utils import load_obj
 
 
-# pylint: disable=too-few-public-methods
 class DummyClass:
     pass
diff --git a/tests/tools/test_cli.py b/tests/tools/test_cli.py
index cf3ce71d1c..346d7af6e6 100644
--- a/tests/tools/test_cli.py
+++ b/tests/tools/test_cli.py
@@ -13,13 +13,9 @@
 REPO_NAME = "cli_tools_dummy_project"
 PACKAGE_NAME = "cli_tools_dummy_package"
 DEFAULT_KEDRO_COMMANDS = [
-    "activate-nbstripout",
-    "build-docs",
-    "build-reqs",
     "catalog",
     "ipython",
     "jupyter",
-    "lint",
     "new",
     "package",
     "pipeline",
@@ -27,7 +23,6 @@
     "registry",
     "run",
     "starter",
-    "test",
 ]