diff --git a/kedro-datasets/RELEASE.md b/kedro-datasets/RELEASE.md
index 9c6deef45..317ddea3a 100755
--- a/kedro-datasets/RELEASE.md
+++ b/kedro-datasets/RELEASE.md
@@ -1,6 +1,9 @@
 # Upcoming Release:
 
+## Bug fixes and other changes
+* Added a warning when the user tries to use `SparkDataSet` on Databricks without specifying a file path with the `/dbfs/` prefix.
+
 # Release 1.0.2:
 
 ## Bug fixes and other changes
 
diff --git a/kedro-datasets/kedro_datasets/spark/spark_dataset.py b/kedro-datasets/kedro_datasets/spark/spark_dataset.py
index ca923c72e..d366eae08 100644
--- a/kedro-datasets/kedro_datasets/spark/spark_dataset.py
+++ b/kedro-datasets/kedro_datasets/spark/spark_dataset.py
@@ -2,6 +2,8 @@
 ``pyspark``
 """
 import json
+import logging
+import os
 from copy import deepcopy
 from fnmatch import fnmatch
 from functools import partial
@@ -23,6 +25,8 @@
 from pyspark.sql.utils import AnalysisException
 from s3fs import S3FileSystem
 
+logger = logging.getLogger(__name__)
+
 
 def _parse_glob_pattern(pattern: str) -> str:
     special = ("*", "?", "[")
@@ -114,6 +118,20 @@ def _dbfs_exists(pattern: str, dbutils: Any) -> bool:
     return False
 
 
+def _deployed_on_databricks() -> bool:
+    """Check if running on Databricks."""
+    return "DATABRICKS_RUNTIME_VERSION" in os.environ
+
+
+def _path_has_dbfs_prefix(path: str) -> bool:
+    """Check if a file path has a valid dbfs prefix.
+
+    Args:
+        path: File path to check.
+    """
+    return path.startswith("/dbfs/")
+
+
 class KedroHdfsInsecureClient(InsecureClient):
     """Subclasses ``hdfs.InsecureClient`` and implements ``hdfs_exists``
     and ``hdfs_glob`` methods required by ``SparkDataSet``"""
@@ -240,9 +258,7 @@ def __init__(  # pylint: disable=too-many-arguments
 
         Args:
             filepath: Filepath in POSIX format to a Spark dataframe. When using Databricks
-                and working with data written to mount path points,
-                specify ``filepath``s for (versioned) ``SparkDataSet``s
-                starting with ``/dbfs/mnt``.
+                specify ``filepath``s starting with ``/dbfs/``.
             file_format: File format used during load and save
                 operations. These are formats supported by the running
                 SparkContext include parquet, csv, delta. For a list of supported
@@ -304,7 +320,12 @@ def __init__(  # pylint: disable=too-many-arguments
 
         else:
             path = PurePosixPath(filepath)
-
+        if _deployed_on_databricks() and not _path_has_dbfs_prefix(filepath):
+            logger.warning(
+                "Using SparkDataSet on Databricks without the `/dbfs/` prefix in the "
+                "filepath is a known source of error. You must add this prefix to %s",
+                filepath,
+            )
         if filepath.startswith("/dbfs"):
             dbutils = _get_dbutils(self._get_spark())
             if dbutils:
diff --git a/kedro-datasets/tests/spark/test_spark_dataset.py b/kedro-datasets/tests/spark/test_spark_dataset.py
index d02f99bff..74c5ee2bf 100644
--- a/kedro-datasets/tests/spark/test_spark_dataset.py
+++ b/kedro-datasets/tests/spark/test_spark_dataset.py
@@ -1,3 +1,4 @@
+# pylint: disable=too-many-lines
 import re
 import sys
 import tempfile
@@ -161,6 +162,7 @@ def isDir(self):
         return "." not in self.path.split("/")[-1]
 
 
+# pylint: disable=too-many-public-methods
 class TestSparkDataSet:
     def test_load_parquet(self, tmp_path, sample_pandas_df):
         temp_path = (tmp_path / "data").as_posix()
@@ -440,6 +442,34 @@ def test_copy(self):
         assert spark_dataset_copy._file_format == "csv"
         assert spark_dataset_copy._save_args == {"mode": "overwrite"}
 
+    def test_dbfs_prefix_warning_no_databricks(self, caplog):
+        # test that warning is not raised when not on Databricks
+        filepath = "my_project/data/02_intermediate/processed_data"
+        expected_message = (
+            "Using SparkDataSet on Databricks without the `/dbfs/` prefix in the "
+            f"filepath is a known source of error. You must add this prefix to {filepath}."
+        )
+        SparkDataSet(filepath="my_project/data/02_intermediate/processed_data")
+        assert expected_message not in caplog.text
+
+    def test_dbfs_prefix_warning_on_databricks_with_prefix(self, monkeypatch, caplog):
+        # test that warning is not raised when on Databricks and filepath has /dbfs prefix
+        filepath = "/dbfs/my_project/data/02_intermediate/processed_data"
+        monkeypatch.setenv("DATABRICKS_RUNTIME_VERSION", "7.3")
+        SparkDataSet(filepath=filepath)
+        assert caplog.text == ""
+
+    def test_dbfs_prefix_warning_on_databricks_no_prefix(self, monkeypatch, caplog):
+        # test that warning is raised when on Databricks and filepath does not have /dbfs prefix
+        filepath = "my_project/data/02_intermediate/processed_data"
+        expected_message = (
+            "Using SparkDataSet on Databricks without the `/dbfs/` prefix in the "
+            f"filepath is a known source of error. You must add this prefix to {filepath}"
+        )
+        monkeypatch.setenv("DATABRICKS_RUNTIME_VERSION", "7.3")
+        SparkDataSet(filepath=filepath)
+        assert expected_message in caplog.text
+
 
 class TestSparkDataSetVersionedLocal:
     def test_no_version(self, versioned_dataset_local):
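Below the diff, a minimal usage sketch (not part of the changeset above) of how the new warning is expected to surface. It assumes `SparkDataSet` is importable from `kedro_datasets.spark`, reuses the file path from the new tests, and sets `DATABRICKS_RUNTIME_VERSION` by hand to stand in for a real Databricks runtime, where the variable is preset.

# Hypothetical illustration only -- not part of the PR.
import logging
import os

from kedro_datasets.spark import SparkDataSet

# Route the dataset's logger output to the console.
logging.basicConfig(level=logging.WARNING)

# Stand in for a Databricks runtime; on a real cluster this variable is preset.
os.environ["DATABRICKS_RUNTIME_VERSION"] = "7.3"

# Warns: "Using SparkDataSet on Databricks without the `/dbfs/` prefix in the
# filepath is a known source of error. You must add this prefix to
# my_project/data/02_intermediate/processed_data"
SparkDataSet(filepath="my_project/data/02_intermediate/processed_data")

# A filepath starting with `/dbfs/` (e.g. "/dbfs/my_project/data/02_intermediate/
# processed_data") passes the new check and emits no warning.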