diff --git a/libs/foundry-dev-tools/src/foundry_dev_tools/clients/foundry_sql_server.py b/libs/foundry-dev-tools/src/foundry_dev_tools/clients/foundry_sql_server.py index 5a1cb70..2b78851 100644 --- a/libs/foundry-dev-tools/src/foundry_dev_tools/clients/foundry_sql_server.py +++ b/libs/foundry-dev-tools/src/foundry_dev_tools/clients/foundry_sql_server.py @@ -4,7 +4,7 @@ import time import warnings -from typing import TYPE_CHECKING, Literal, overload +from typing import TYPE_CHECKING, Any, Literal, overload from foundry_dev_tools.clients.api_client import APIClient from foundry_dev_tools.errors.handling import ErrorHandlingConfig @@ -278,3 +278,263 @@ def api_queries_results( }, **kwargs, ) + + +class FoundrySqlServerClientV2(APIClient): + """FoundrySqlServerClientV2 implements the newer foundry-sql-server API. + + This client uses a different API flow compared to V1: + - Executes queries via POST to /api/ with applicationId and sql + - Polls POST to /api/status for query completion + - Retrieves results via POST to /api/stream with tickets + """ + + api_name = "foundry-sql-server" + + @overload + def query_foundry_sql( + self, + query: str, + application_id: str, + return_type: Literal["pandas"], + disable_arrow_compression: bool = ..., + timeout: int = ..., + ) -> pd.core.frame.DataFrame: ... + + @overload + def query_foundry_sql( + self, + query: str, + application_id: str, + return_type: Literal["spark"], + disable_arrow_compression: bool = ..., + timeout: int = ..., + ) -> pyspark.sql.DataFrame: ... + + @overload + def query_foundry_sql( + self, + query: str, + application_id: str, + return_type: Literal["arrow"], + disable_arrow_compression: bool = ..., + timeout: int = ..., + ) -> pa.Table: ... + + @overload + def query_foundry_sql( + self, + query: str, + application_id: str, + return_type: SQLReturnType = ..., + disable_arrow_compression: bool = ..., + timeout: int = ..., + ) -> tuple[dict, list[list]] | pd.core.frame.DataFrame | pa.Table | pyspark.sql.DataFrame: ... + + def query_foundry_sql( + self, + query: str, + return_type: SQLReturnType = "pandas", + disable_arrow_compression: bool = False, + application_id: str | None = None, + ) -> tuple[dict, list[list]] | pd.core.frame.DataFrame | pa.Table | pyspark.sql.DataFrame: + """Queries the Foundry SQL server using the V2 API. + + Uses Arrow IPC to communicate with the Foundry SQL Server Endpoint. + + Example: + df = client.query_foundry_sql( + query="SELECT * FROM `ri.foundry.main.dataset.abc` LIMIT 10", + application_id="ri.foundry.main.dataset.abc" + ) + + Args: + query: The SQL Query + return_type: See :py:class:foundry_dev_tools.foundry_api_client.SQLReturnType + disable_arrow_compression: Whether to disable Arrow compression + application_id: The application/dataset RID, defaults to foundry-dev-tools User-Agent + + Returns: + :external+pandas:py:class:`~pandas.DataFrame` | :external+pyarrow:py:class:`~pyarrow.Table` | :external+spark:py:class:`~pyspark.sql.DataFrame`: + + A pandas DataFrame, Spark DataFrame or pyarrow.Table with the result. + + Raises: + FoundrySqlQueryFailedError: If the query fails + FoundrySqlQueryClientTimedOutError: If the query times out + + """ # noqa: E501 + # Execute the query + if not application_id: + application_id = self.context.client.headers["User-Agent"] + response_json = self.api_execute( + sql=query, + application_id=application_id, + disable_arrow_compression=disable_arrow_compression, + ).json() + + query_identifier = self._extract_query_identifier(response_json) + + # Poll for completion + while response_json.get("type") != "success": + time.sleep(0.2) + response = self.api_status(query_identifier) + response_json = response.json() + + if response_json.get("type") == "failed": + raise FoundrySqlQueryFailedError(response) + + # Extract tickets from successful response + tickets = self._extract_tickets(response_json) + + # Fetch Arrow data using tickets + arrow_stream_reader = self.read_stream_results_arrow(tickets) + + if return_type == "pandas": + return arrow_stream_reader.read_pandas() + + if return_type == "spark": + from foundry_dev_tools.utils.converter.foundry_spark import ( + arrow_stream_to_spark_dataframe, + ) + + return arrow_stream_to_spark_dataframe(arrow_stream_reader) + + if return_type == "arrow": + return arrow_stream_reader.read_all() + + raise ValueError("The following return_type is not supported: " + return_type) + + def _extract_query_identifier(self, response_json: dict[str, Any]) -> dict[str, Any]: + """Extract query identifier from execute response. + + Args: + response_json: Response JSON from execute API + + Returns: + Query identifier dict + + """ + if response_json.get("type") == "pending": + return response_json["pending"]["query"] + if response_json.get("type") == "success": + return response_json["success"]["query"] + msg = f"Unexpected response type: {response_json.get('type')}" + raise ValueError(msg) + + def _extract_tickets(self, response_json: dict[str, Any]) -> list[str]: + """Extract tickets from success response. + + Args: + response_json: Success response JSON from status API + + Returns: + List of tickets for fetching results + + """ + if response_json.get("type") != "success": + msg = f"Expected success response, got: {response_json.get('type')}" + + raise ValueError(msg) + + chunks = response_json["success"]["result"]["interactive"]["chunks"] + return [chunk["ticket"] for chunk in chunks] + + def read_stream_results_arrow(self, tickets: list[str]) -> pa.ipc.RecordBatchStreamReader: + """Fetch query results using tickets and return Arrow stream reader. + + Args: + tickets: List of tickets from status API success response + + Returns: + Arrow RecordBatchStreamReader + + """ + from foundry_dev_tools._optional.pyarrow import pa + + response = self._api_stream_ticket(tickets) + response.raw.decode_content = True + + return pa.ipc.RecordBatchStreamReader(response.raw) + + def api_execute( + self, + sql: str, + application_id: str, + disable_arrow_compression: bool = False, + **kwargs, + ) -> requests.Response: + """Execute a SQL query via the V2 API. + + Args: + sql: The SQL query to execute + application_id: The application/dataset RID + disable_arrow_compression: Whether to disable Arrow compression + **kwargs: gets passed to :py:meth:`APIClient.api_request` + + Returns: + Response with query execution status + + """ + return self.api_request( + "POST", + "", # Root endpoint /api/ + json={ + "applicationId": application_id, + "sql": sql, + "disableArrowCompression": disable_arrow_compression, + }, + **kwargs, + ) + + def api_status( + self, + query_identifier: dict[str, Any], + **kwargs, + ) -> requests.Response: + """Get the status of a SQL query via the V2 API. + + Args: + query_identifier: Query identifier dict (e.g., {"type": "interactive", "interactive": "query-id"}) + **kwargs: gets passed to :py:meth:`APIClient.api_request` + + Returns: + Response with query status + + """ + return self.api_request( + "POST", + "status", + json={ + "query": query_identifier, + }, + **kwargs, + ) + + def _api_stream_ticket( + self, + tickets: list[str], + **kwargs, + ) -> requests.Response: + """Fetch query results using tickets via the V2 API. + + Args: + tickets: List of tickets from status API success response + **kwargs: gets passed to :py:meth:`APIClient.api_request` + + Returns: + Response with Arrow-encoded query results + + """ + return self.api_request( + "POST", + "stream", + json={ + "tickets": tickets, + }, + headers={ + "Accept": "application/octet-stream", + }, + stream=True, + **kwargs, + ) diff --git a/libs/foundry-dev-tools/src/foundry_dev_tools/config/context.py b/libs/foundry-dev-tools/src/foundry_dev_tools/config/context.py index ebeaecf..0ade496 100644 --- a/libs/foundry-dev-tools/src/foundry_dev_tools/config/context.py +++ b/libs/foundry-dev-tools/src/foundry_dev_tools/config/context.py @@ -146,6 +146,11 @@ def foundry_sql_server(self) -> foundry_sql_server.FoundrySqlServerClient: """Returns :py:class:`foundry_dev_tools.clients.foundry_sql_server.FoundrySqlServerClient`.""" return foundry_sql_server.FoundrySqlServerClient(self) + @cached_property + def foundry_sql_server_v2(self) -> foundry_sql_server.FoundrySqlServerClientV2: + """Returns :py:class:`foundry_dev_tools.clients.foundry_sql_server.FoundrySqlServerClientV2`.""" + return foundry_sql_server.FoundrySqlServerClientV2(self) + @cached_property def build2(self) -> build2.Build2Client: """Returns :py:class:`foundry_dev_tools.clients.build2.Build2Client`.""" diff --git a/libs/foundry-dev-tools/src/foundry_dev_tools/errors/handling.py b/libs/foundry-dev-tools/src/foundry_dev_tools/errors/handling.py index 1287123..4c50277 100644 --- a/libs/foundry-dev-tools/src/foundry_dev_tools/errors/handling.py +++ b/libs/foundry-dev-tools/src/foundry_dev_tools/errors/handling.py @@ -47,9 +47,7 @@ ) from foundry_dev_tools.errors.meta import FoundryAPIError from foundry_dev_tools.errors.multipass import DuplicateGroupNameError -from foundry_dev_tools.errors.sql import ( - FoundrySqlQueryFailedError, -) +from foundry_dev_tools.errors.sql import FoundrySqlQueryFailedError, FurnaceSqlSqlParseError from foundry_dev_tools.utils.misc import decamelize LOGGER = logging.getLogger(__name__) @@ -59,6 +57,7 @@ "DataProxy:SchemaNotFound": DatasetHasNoSchemaError, "DataProxy:FallbackBranchesNotSpecifiedInQuery": BranchNotFoundError, "DataProxy:BadSqlQuery": FoundrySqlQueryFailedError, + "FurnaceSql:SqlParseError": FurnaceSqlSqlParseError, "DataProxy:DatasetNotFound": DatasetNotFoundError, "Catalog:DuplicateDatasetName": DatasetAlreadyExistsError, "Catalog:DatasetsNotFound": DatasetNotFoundError, diff --git a/libs/foundry-dev-tools/src/foundry_dev_tools/errors/sql.py b/libs/foundry-dev-tools/src/foundry_dev_tools/errors/sql.py index efe789a..567cd5d 100644 --- a/libs/foundry-dev-tools/src/foundry_dev_tools/errors/sql.py +++ b/libs/foundry-dev-tools/src/foundry_dev_tools/errors/sql.py @@ -20,6 +20,12 @@ def __init__(self, response: requests.Response): super().__init__(response=response, info=self.error_message) +class FurnaceSqlSqlParseError(FoundryAPIError): + """Exception is thrown when SQL Query is not valid.""" + + message = "Foundry SQL Query Parsing Failed." + + class FoundrySqlQueryClientTimedOutError(FoundryAPIError): """Exception is thrown when the Query execution time exceeded the client timeout value.""" diff --git a/tests/integration/clients/test_foundry_sql_server.py b/tests/integration/clients/test_foundry_sql_server.py index 7dd079d..2e03614 100644 --- a/tests/integration/clients/test_foundry_sql_server.py +++ b/tests/integration/clients/test_foundry_sql_server.py @@ -1,7 +1,11 @@ import pytest from foundry_dev_tools.errors.dataset import BranchNotFoundError, DatasetHasNoSchemaError, DatasetNotFoundError -from foundry_dev_tools.errors.sql import FoundrySqlQueryFailedError, FoundrySqlSerializationFormatNotImplementedError +from foundry_dev_tools.errors.sql import ( + FoundrySqlQueryFailedError, + FoundrySqlSerializationFormatNotImplementedError, + FurnaceSqlSqlParseError, +) from tests.integration.conftest import TEST_SINGLETON @@ -56,3 +60,108 @@ def test_legacy_fallback(mocker): TEST_SINGLETON.ctx.foundry_sql_server.query_foundry_sql(f"SELECT * FROM `{TEST_SINGLETON.iris_new.rid}`") query_foundry_sql_legacy_spy.assert_called() + + +# V2 Client Tests + + +def test_v2_smoke(): + """Test basic V2 client functionality with a simple query.""" + one_row_one_column = TEST_SINGLETON.ctx.foundry_sql_server_v2.query_foundry_sql( + query=f"SELECT sepal_width FROM `{TEST_SINGLETON.iris_new.rid}` LIMIT 1", + application_id=TEST_SINGLETON.iris_new.rid, + ) + assert one_row_one_column.shape == (1, 1) + + one_row_one_column = TEST_SINGLETON.ctx.foundry_sql_server_v2.query_foundry_sql( + query=f"SELECT sepal_width FROM `{TEST_SINGLETON.iris_new.rid}` LIMIT 1", + application_id=TEST_SINGLETON.iris_new.rid, + return_type="arrow", + ) + assert one_row_one_column.num_columns == 1 + assert one_row_one_column.num_rows == 1 + assert one_row_one_column.column_names == ["sepal_width"] + + +def test_v2_multiple_rows(): + """Test V2 client with multiple rows.""" + result = TEST_SINGLETON.ctx.foundry_sql_server_v2.query_foundry_sql( + query=f"SELECT * FROM `{TEST_SINGLETON.iris_new.rid}` LIMIT 10", + application_id=TEST_SINGLETON.iris_new.rid, + ) + assert result.shape[0] == 10 + assert result.shape[1] == 5 # iris dataset has 5 columns + + +def test_v2_return_type_arrow(): + """Test V2 client with Arrow return type.""" + import pyarrow as pa + + result = TEST_SINGLETON.ctx.foundry_sql_server_v2.query_foundry_sql( + query=f"SELECT * FROM `{TEST_SINGLETON.iris_new.rid}` LIMIT 5", + application_id=TEST_SINGLETON.iris_new.rid, + return_type="arrow", + ) + assert isinstance(result, pa.Table) + assert result.num_rows == 5 + + +def test_v2_return_type_raw_not_supported(): + """Test V2 client with raw return type.""" + with pytest.raises(ValueError, match="The following return_type is not supported: .+"): + schema, rows = TEST_SINGLETON.ctx.foundry_sql_server_v2.query_foundry_sql( + query=f"SELECT sepal_width, sepal_length FROM `{TEST_SINGLETON.iris_new.rid}` LIMIT 3", + application_id=TEST_SINGLETON.iris_new.rid, + return_type="raw", + ) + + +def test_v2_aggregation_query(): + """Test V2 client with aggregation query.""" + result = TEST_SINGLETON.ctx.foundry_sql_server_v2.query_foundry_sql( + query=f""" + SELECT + COUNT(*) as total_count, + AVG(sepal_width) as avg_sepal_width + FROM `{TEST_SINGLETON.iris_new.rid}` + """, + application_id=TEST_SINGLETON.iris_new.rid, + ) + assert result.shape == (1, 2) + assert "total_count" in result.columns + assert "avg_sepal_width" in result.columns + + +def test_v2_query_failed(): + """Test V2 client with invalid SQL query.""" + with pytest.raises(FurnaceSqlSqlParseError): + TEST_SINGLETON.ctx.foundry_sql_server_v2.query_foundry_sql( + query=f"SELECT foo, bar, FROM `{TEST_SINGLETON.iris_new.rid}` LIMIT 100", + application_id=TEST_SINGLETON.iris_new.rid, + ) + + +def test_v2_disable_arrow_compression(): + """Test V2 client with arrow compression disabled.""" + result = TEST_SINGLETON.ctx.foundry_sql_server_v2.query_foundry_sql( + query=f"SELECT * FROM `{TEST_SINGLETON.iris_new.rid}` LIMIT 5", + application_id=TEST_SINGLETON.iris_new.rid, + disable_arrow_compression=True, + ) + assert result.shape[0] == 5 + + +def test_v2_with_where_clause(): + """Test V2 client with WHERE clause.""" + result = TEST_SINGLETON.ctx.foundry_sql_server_v2.query_foundry_sql( + query=f""" + SELECT * FROM `{TEST_SINGLETON.iris_new.rid}` + WHERE is_setosa = 'setosa' + LIMIT 20 + """, + application_id=TEST_SINGLETON.iris_new.rid, + ) + assert result.shape[0] <= 20 + # Verify all returned rows have is_setosa = 'setosa' + if result.shape[0] > 0: + assert all(result["is_setosa"] == "setosa")