diff --git a/kedro-datasets/kedro_datasets/spark/__init__.py b/kedro-datasets/kedro_datasets/spark/__init__.py
index bba4e3df0..db987959e 100644
--- a/kedro-datasets/kedro_datasets/spark/__init__.py
+++ b/kedro-datasets/kedro_datasets/spark/__init__.py
@@ -10,6 +10,7 @@
 SparkHiveDataset: Any
 SparkJDBCDataset: Any
 SparkStreamingDataset: Any
+GBQQueryDataset: Any
 
 __getattr__, __dir__, __all__ = lazy.attach(
     __name__,
@@ -19,5 +20,6 @@
         "spark_hive_dataset": ["SparkHiveDataset"],
         "spark_jdbc_dataset": ["SparkJDBCDataset"],
         "spark_streaming_dataset": ["SparkStreamingDataset"],
+        "spark_gbq_dataset": ["GBQQueryDataset"],
     },
 )
diff --git a/kedro-datasets/kedro_datasets/spark/spark_gbq_dataset.py b/kedro-datasets/kedro_datasets/spark/spark_gbq_dataset.py
new file mode 100644
index 000000000..5379f3a00
--- /dev/null
+++ b/kedro-datasets/kedro_datasets/spark/spark_gbq_dataset.py
@@ -0,0 +1,240 @@
+"""``AbstractDataset`` implementation to load data from Google BigQuery
+with a SQL query, using the Spark BigQuery connector.
+"""
+
+from __future__ import annotations
+
+import base64
+import json
+import logging
+from copy import deepcopy
+from typing import Any, NoReturn
+
+import fsspec
+from kedro.io import AbstractDataset, DatasetError
+from kedro.io.core import get_protocol_and_path
+from py4j.protocol import Py4JJavaError
+from pyspark.sql import DataFrame
+
+from kedro_datasets._utils.spark_utils import get_spark
+
+logger = logging.getLogger(__name__)
+
+
+class GBQQueryDataset(AbstractDataset[None, DataFrame]):
+    """``GBQQueryDataset`` loads data from Google BigQuery with a SQL query using the Spark BigQuery connector.
+
+    Example usage for the
+    `YAML API `_:
+
+    .. code-block:: yaml
+
+        my_gbq_spark_data:
+          type: spark.GBQQueryDataset
+          sql: |
+            SELECT * FROM your_table
+          materialization_dataset: your_dataset
+          materialization_project: your_project
+          bq_credentials:
+            file: /path/to/your/credentials.json
+          fs_credentials:
+            key: value
+
+    Example usage for the
+    `Python API `_:
+
+    .. code-block:: pycon
+
+        >>> from kedro_datasets.spark import GBQQueryDataset
+        >>>
+        >>> # Define your SQL query
+        >>> sql = "SELECT * FROM your_table"
+        >>>
+        >>> # Initialize dataset
+        >>> dataset = GBQQueryDataset(
+        ...     sql=sql,
+        ...     materialization_dataset="your_dataset",
+        ...     materialization_project="your_project",  # optional
+        ...     bq_credentials=dict(file="/path/to/your/credentials.json"),  # optional
+        ...     fs_credentials=dict(key="value"),  # optional
+        ... )
+        >>>
+        >>> # Load data
+        >>> df = dataset.load()
+        >>>
+        >>> # Example output
+        >>> df.show()
+    """
+
+    _VALID_CREDENTIALS_KEYS = {"base64", "file", "json"}
+
+    def __init__(  # noqa: PLR0913
+        self,
+        materialization_dataset: str,
+        sql: str | None = None,
+        filepath: str | None = None,
+        materialization_project: str | None = None,
+        load_args: dict[str, Any] | None = None,
+        fs_args: dict[str, Any] | None = None,
+        bq_credentials: dict[str, Any] | None = None,
+        fs_credentials: dict[str, Any] | None = None,
+        metadata: dict[str, Any] | None = None,
+    ) -> None:
+        """Creates a new instance of ``GBQQueryDataset`` that loads the result of a SQL query from Google BigQuery.
+
+        Args:
+            materialization_dataset: The name of the dataset to materialize the query results.
+            sql: The SQL query to execute.
+            filepath: A path to a file with a SQL query statement.
+            materialization_project: The name of the project to materialize the query results.
+                Optional (defaults to the project id set by the credentials).
+            load_args: Load args passed to the Spark DataFrameReader load method.
+                It is dependent on the selected file format. You can find
+                a list of read options for each supported format
+                in the Spark DataFrame read documentation:
+                https://spark.apache.org/docs/latest/api/python/getting_started/quickstart_df.html
+            fs_args: Extra arguments passed to the ``fsspec`` filesystem constructor
+                (used only when ``filepath`` is provided).
+            bq_credentials: Credentials to authenticate the Spark session with Google BigQuery.
+                Dictionary with a key specifying the type of credentials ('base64', 'file', 'json').
+                Alternatively, you can pass the credentials in load_args as follows:
+
+                When passing as `base64`:
+                `load_args={"credentials": "your_credentials"}`
+
+                When passing as a `file`:
+                `load_args={"credentialsFile": "/path/to/your/credentials.json"}`
+
+                When passing as a json object:
+                NOT SUPPORTED
+
+                Read more here:
+                https://github.com/GoogleCloudDataproc/spark-bigquery-connector?tab=readme-ov-file#how-do-i-authenticate-outside-gce--dataproc
+            fs_credentials: Credentials to authenticate with the filesystem.
+                The keyword args are passed directly to the ``fsspec.filesystem`` constructor.
+            metadata: Any arbitrary metadata.
+                This is ignored by Kedro, but may be consumed by users or external plugins.
+        """
+        if sql and filepath:
+            raise DatasetError(
+                "'sql' and 'filepath' arguments cannot both be provided. "
+                "Please only provide one."
+            )
+
+        if not (sql or filepath):
+            raise DatasetError(
+                "'sql' and 'filepath' arguments cannot both be empty. "
+                "Please provide a sql query or path to a sql query file."
+            )
+
+        if sql:
+            self._sql = sql
+            self._filepath = None
+        else:
+            # TODO: Add protocol specific handling cases for different filesystems.
+            protocol, path = get_protocol_and_path(str(filepath))
+
+            self._fs_args = fs_args or {}
+            self._fs_credentials = fs_credentials or {}
+            self._fs_protocol = protocol
+
+            self._fs = fsspec.filesystem(
+                self._fs_protocol, **self._fs_credentials, **self._fs_args
+            )
+            self._sql = None
+            self._filepath = path
+
+        self._materialization_dataset = materialization_dataset
+        self._materialization_project = materialization_project
+        self._load_args = load_args or {}
+        self._bq_credentials = bq_credentials or {}
+
+        self._metadata = metadata
+
+    def _get_spark_bq_credentials(self) -> dict[str, str]:
+        if not self._bq_credentials:
+            return {}
+
+        if len(self._bq_credentials) > 1:
+            raise ValueError(
+                "Please provide only one of 'base64', 'file' or 'json' key in the credentials. "
+                f"You provided: {list(self._bq_credentials.keys())}"
+            )
+        if self._bq_credentials.get("base64"):
+            return {
+                "credentials": self._bq_credentials["base64"],
+            }
+        if self._bq_credentials.get("file"):
+            return {
+                "credentialsFile": self._bq_credentials["file"],
+            }
+        if self._bq_credentials.get("json"):
+            creds_b64 = base64.b64encode(
+                json.dumps(self._bq_credentials["json"]).encode("utf-8")
+            ).decode("utf-8")
+            return {"credentials": creds_b64}
+
+        raise ValueError(
+            "Please provide one of 'base64', 'file' or 'json' key in the credentials. "
+            f"You provided: {list(self._bq_credentials.keys())[0]}"
+        )
+
+    def _load_sql_from_filepath(self) -> str:
+        with self._fs.open(self._filepath, "r") as f:
+            return f.read()
+
+    def _get_sql(self) -> str:
+        if self._sql:
+            return self._sql
+        else:
+            return self._load_sql_from_filepath()
+
+    def _get_spark_load_args(self) -> dict[str, Any]:
+        spark_load_args = deepcopy(self._load_args)
+        spark_load_args["query"] = self._get_sql()
+        spark_load_args["materializationDataset"] = self._materialization_dataset
+
+        if self._materialization_project:
+            spark_load_args["materializationProject"] = self._materialization_project
+
+        spark_load_args.update(self._get_spark_bq_credentials())
+
+        try:
+            views_enabled_spark_conf = get_spark().conf.get("viewsEnabled")
+        except Py4JJavaError:
+            views_enabled_spark_conf = "false"
+
+        if views_enabled_spark_conf != "true":
+            spark_load_args["viewsEnabled"] = "true"
+            logger.warning(
+                "The 'viewsEnabled' configuration is not set to 'true' in the SparkSession. "
+                "This is required for the Spark BigQuery connector to read via a SQL query. "
+                "Setting 'viewsEnabled' to 'true' for the current query read operation. "
+                "This may incur additional costs!"
+            )
+
+        return spark_load_args
+
+    def load(self) -> DataFrame:
+        """Loads data from Google BigQuery.
+
+        Returns:
+            A Spark DataFrame.
+        """
+        spark = get_spark()
+        read_obj = spark.read.format("bigquery")
+
+        return read_obj.load(**self._get_spark_load_args())
+
+    def save(self, data: None) -> NoReturn:
+        raise DatasetError("'save' is not supported on GBQQueryDataset")
+
+    def _describe(self) -> dict[str, Any]:
+        return {
+            "sql": self._sql,
+            "filepath": self._filepath,
+            "materialization_dataset": self._materialization_dataset,
+            "materialization_project": self._materialization_project,
+            "load_args": self._load_args,
+            "metadata": self._metadata,
+        }
diff --git a/kedro-datasets/tests/spark/test_spark_gbq_dataset.py b/kedro-datasets/tests/spark/test_spark_gbq_dataset.py
new file mode 100644
index 000000000..83a12841b
--- /dev/null
+++ b/kedro-datasets/tests/spark/test_spark_gbq_dataset.py
@@ -0,0 +1,140 @@
+import base64
+import json
+import re
+
+import pytest
+from kedro.io import DatasetError
+from pyspark.sql import SparkSession
+
+from kedro_datasets.spark.spark_gbq_dataset import GBQQueryDataset
+
+SQL_QUERY = "SELECT * FROM table"
+MATERIALIZATION_DATASET = "dataset"
+MATERIALIZATION_PROJECT = "project"
+LOAD_ARGS = {"key": "value"}
+REQUIRED_INIT_ARGS = {
+    "sql": SQL_QUERY,
+    "materialization_dataset": MATERIALIZATION_DATASET,
+}
+
+
+@pytest.fixture
+def spark_session(mocker):
+    return mocker.MagicMock(spec=SparkSession)
+
+
+@pytest.fixture
+def dummy_save_dataset(spark_session):
+    return spark_session.createDataFrame([("foo",)], ["bar"])
+
+
+@pytest.fixture
+def gbq_query_dataset():
+    return GBQQueryDataset(
+        sql=SQL_QUERY,
+        materialization_dataset=MATERIALIZATION_DATASET,
+        materialization_project=MATERIALIZATION_PROJECT,
+        load_args=LOAD_ARGS,
+    )
+
+
+def test_save_not_implemented(gbq_query_dataset, dummy_save_dataset):
+    with pytest.raises(
+        DatasetError,
+        match=r"'save' is not supported on GBQQueryDataset",
+    ):
+        gbq_query_dataset.save(dummy_save_dataset)
+
+
+@pytest.mark.parametrize(
+    "credentials, expected_credentials",
+    [
+        ({"base64": "base64_creds"}, {"credentials": "base64_creds"}),
+        ({"file": "/path/to/creds.json"}, {"credentialsFile": "/path/to/creds.json"}),
+        (
+            {"json": {"type": "service_account"}},
+            {
+                "credentials": base64.b64encode(
+                    json.dumps({"type": "service_account"}).encode("utf-8")
+                ).decode("utf-8")
+            },
+        ),
+        ({}, {}),
+    ],
+)
+def test_get_spark_bq_credentials(gbq_query_dataset, credentials, expected_credentials):
+    gbq_query_dataset._bq_credentials = credentials
+    assert gbq_query_dataset._get_spark_bq_credentials() == expected_credentials
+
+
+def test_invalid_bq_credentials_key(gbq_query_dataset):
+    invalid_cred_key = "invalid_cred_key"
+    gbq_query_dataset._bq_credentials = {invalid_cred_key: "value"}
+    with pytest.raises(
+        ValueError,
+        match=f"Please provide one of 'base64', 'file' or 'json' key in the credentials. You provided: {invalid_cred_key}",
+    ):
+        gbq_query_dataset._get_spark_bq_credentials()
+
+
+@pytest.mark.parametrize(
+    "credentials",
+    [
+        {"base64": "base64_creds", "file": "/path/to/creds.json"},
+        {"base64": "base64_creds", "json": {"type": "service_account"}},
+        {"file": "/path/to/creds.json", "json": {"type": "service_account"}},
+        {
+            "base64": "base64_creds",
+            "file": "/path/to/creds.json",
+            "json": {"type": "service_account"},
+        },
+        {"base64": "base64_creds", "invalid_key": "value"},
+    ],
+)
+def test_more_than_one_bq_credentials_key(gbq_query_dataset, credentials):
+    gbq_query_dataset._bq_credentials = credentials
+    pattern = re.escape(
+        "Please provide only one of 'base64', 'file' or 'json' key in the credentials. "
+        f"You provided: {list(credentials.keys())}"
+    )
+    with pytest.raises(
+        ValueError,
+        match=pattern,
+    ):
+        gbq_query_dataset._get_spark_bq_credentials()
+
+
+@pytest.mark.parametrize(
+    "init_args, expected_load_args",
+    [
+        (
+            REQUIRED_INIT_ARGS,
+            {
+                "query": REQUIRED_INIT_ARGS["sql"],
+                "materializationDataset": REQUIRED_INIT_ARGS["materialization_dataset"],
+                "viewsEnabled": "true",
+            },
+        ),
+        (
+            {**REQUIRED_INIT_ARGS, "materialization_project": MATERIALIZATION_PROJECT},
+            {
+                "query": REQUIRED_INIT_ARGS["sql"],
+                "materializationDataset": REQUIRED_INIT_ARGS["materialization_dataset"],
+                "materializationProject": MATERIALIZATION_PROJECT,
+                "viewsEnabled": "true",
+            },
+        ),
+    ],
+)
+def test_load(mocker, spark_session, init_args, expected_load_args):
+    gbq_query_dataset = GBQQueryDataset(**init_args)
+    mocker.patch(
+        "kedro_datasets.spark.spark_gbq_dataset.get_spark", return_value=spark_session
+    )
+    read_obj = mocker.MagicMock()
+    spark_session.read.format.return_value = read_obj
+    read_obj.load.return_value = mocker.MagicMock()
+
+    gbq_query_dataset.load()
+
+    spark_session.read.format.assert_called_once_with("bigquery")
+    read_obj.load.assert_called_once_with(**expected_load_args)
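
Usage note: the `bq_credentials` docstring above lists three ways to pass credentials ('base64', 'file', 'json'). A minimal sketch of producing the base64 string expected by the `base64` key from a service-account key file; the file path and query below are placeholders, not part of this change:

    import base64

    from kedro_datasets.spark import GBQQueryDataset

    # Hypothetical path to a service-account key file downloaded from GCP.
    with open("/path/to/your/credentials.json", "rb") as f:
        creds_b64 = base64.b64encode(f.read()).decode("utf-8")

    dataset = GBQQueryDataset(
        sql="SELECT * FROM your_table",
        materialization_dataset="your_dataset",
        bq_credentials={"base64": creds_b64},
    )
    df = dataset.load()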
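
Similarly, the `viewsEnabled` warning emitted by `_get_spark_load_args` can be avoided by configuring the option on the SparkSession up front. A minimal sketch, assuming an already-built session and that the connector picks the option up from the Spark conf (as the dataset itself checks via `get_spark().conf.get("viewsEnabled")`):

    from pyspark.sql import SparkSession

    spark = SparkSession.builder.getOrCreate()
    # Values are strings, matching how _get_spark_load_args reads the conf.
    spark.conf.set("viewsEnabled", "true")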