Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

import time
import warnings
from typing import TYPE_CHECKING, Literal, overload
from typing import TYPE_CHECKING, Any, Literal, overload

from foundry_dev_tools.clients.api_client import APIClient
from foundry_dev_tools.errors.handling import ErrorHandlingConfig
Expand Down Expand Up @@ -278,3 +278,263 @@ def api_queries_results(
},
**kwargs,
)


class FoundrySqlServerClientV2(APIClient):
"""FoundrySqlServerClientV2 implements the newer foundry-sql-server API.

This client uses a different API flow compared to V1:
- Executes queries via POST to /api/ with applicationId and sql
- Polls POST to /api/status for query completion
- Retrieves results via POST to /api/stream with tickets
"""

api_name = "foundry-sql-server"

@overload
def query_foundry_sql(
self,
query: str,
application_id: str,
return_type: Literal["pandas"],
disable_arrow_compression: bool = ...,
timeout: int = ...,
) -> pd.core.frame.DataFrame: ...

@overload
def query_foundry_sql(
self,
query: str,
application_id: str,
return_type: Literal["spark"],
disable_arrow_compression: bool = ...,
timeout: int = ...,
) -> pyspark.sql.DataFrame: ...

@overload
def query_foundry_sql(
self,
query: str,
application_id: str,
return_type: Literal["arrow"],
disable_arrow_compression: bool = ...,
timeout: int = ...,
) -> pa.Table: ...

@overload
def query_foundry_sql(
self,
query: str,
application_id: str,
return_type: SQLReturnType = ...,
disable_arrow_compression: bool = ...,
timeout: int = ...,
) -> tuple[dict, list[list]] | pd.core.frame.DataFrame | pa.Table | pyspark.sql.DataFrame: ...

def query_foundry_sql(
self,
query: str,
return_type: SQLReturnType = "pandas",
disable_arrow_compression: bool = False,
application_id: str | None = None,
) -> tuple[dict, list[list]] | pd.core.frame.DataFrame | pa.Table | pyspark.sql.DataFrame:
"""Queries the Foundry SQL server using the V2 API.

Uses Arrow IPC to communicate with the Foundry SQL Server Endpoint.

Example:
df = client.query_foundry_sql(
query="SELECT * FROM `ri.foundry.main.dataset.abc` LIMIT 10",
application_id="ri.foundry.main.dataset.abc"
)

Args:
query: The SQL Query
return_type: See :py:class:foundry_dev_tools.foundry_api_client.SQLReturnType
disable_arrow_compression: Whether to disable Arrow compression
application_id: The application/dataset RID, defaults to foundry-dev-tools User-Agent

Returns:
:external+pandas:py:class:`~pandas.DataFrame` | :external+pyarrow:py:class:`~pyarrow.Table` | :external+spark:py:class:`~pyspark.sql.DataFrame`:

A pandas DataFrame, Spark DataFrame or pyarrow.Table with the result.

Raises:
FoundrySqlQueryFailedError: If the query fails
FoundrySqlQueryClientTimedOutError: If the query times out

""" # noqa: E501
# Execute the query
if not application_id:
application_id = self.context.client.headers["User-Agent"]
response_json = self.api_execute(
sql=query,
application_id=application_id,
disable_arrow_compression=disable_arrow_compression,
).json()

query_identifier = self._extract_query_identifier(response_json)

# Poll for completion
while response_json.get("type") != "success":
time.sleep(0.2)
response = self.api_status(query_identifier)
response_json = response.json()

if response_json.get("type") == "failed":
raise FoundrySqlQueryFailedError(response)

# Extract tickets from successful response
tickets = self._extract_tickets(response_json)

# Fetch Arrow data using tickets
arrow_stream_reader = self.read_stream_results_arrow(tickets)

if return_type == "pandas":
return arrow_stream_reader.read_pandas()

if return_type == "spark":
from foundry_dev_tools.utils.converter.foundry_spark import (
arrow_stream_to_spark_dataframe,
)

return arrow_stream_to_spark_dataframe(arrow_stream_reader)

if return_type == "arrow":
return arrow_stream_reader.read_all()

raise ValueError("The following return_type is not supported: " + return_type)

def _extract_query_identifier(self, response_json: dict[str, Any]) -> dict[str, Any]:
"""Extract query identifier from execute response.

Args:
response_json: Response JSON from execute API

Returns:
Query identifier dict

"""
if response_json.get("type") == "pending":
return response_json["pending"]["query"]
if response_json.get("type") == "success":
return response_json["success"]["query"]
msg = f"Unexpected response type: {response_json.get('type')}"
raise ValueError(msg)

def _extract_tickets(self, response_json: dict[str, Any]) -> list[str]:
"""Extract tickets from success response.

Args:
response_json: Success response JSON from status API

Returns:
List of tickets for fetching results

"""
if response_json.get("type") != "success":
msg = f"Expected success response, got: {response_json.get('type')}"

raise ValueError(msg)

chunks = response_json["success"]["result"]["interactive"]["chunks"]
return [chunk["ticket"] for chunk in chunks]

def read_stream_results_arrow(self, tickets: list[str]) -> pa.ipc.RecordBatchStreamReader:
"""Fetch query results using tickets and return Arrow stream reader.

Args:
tickets: List of tickets from status API success response

Returns:
Arrow RecordBatchStreamReader

"""
from foundry_dev_tools._optional.pyarrow import pa

response = self._api_stream_ticket(tickets)
response.raw.decode_content = True

return pa.ipc.RecordBatchStreamReader(response.raw)

def api_execute(
self,
sql: str,
application_id: str,
disable_arrow_compression: bool = False,
**kwargs,
) -> requests.Response:
"""Execute a SQL query via the V2 API.

Args:
sql: The SQL query to execute
application_id: The application/dataset RID
disable_arrow_compression: Whether to disable Arrow compression
**kwargs: gets passed to :py:meth:`APIClient.api_request`

Returns:
Response with query execution status

"""
return self.api_request(
"POST",
"", # Root endpoint /api/
json={
"applicationId": application_id,
"sql": sql,
"disableArrowCompression": disable_arrow_compression,
},
**kwargs,
)

def api_status(
self,
query_identifier: dict[str, Any],
**kwargs,
) -> requests.Response:
"""Get the status of a SQL query via the V2 API.

Args:
query_identifier: Query identifier dict (e.g., {"type": "interactive", "interactive": "query-id"})
**kwargs: gets passed to :py:meth:`APIClient.api_request`

Returns:
Response with query status

"""
return self.api_request(
"POST",
"status",
json={
"query": query_identifier,
},
**kwargs,
)

def _api_stream_ticket(
self,
tickets: list[str],
**kwargs,
) -> requests.Response:
"""Fetch query results using tickets via the V2 API.

Args:
tickets: List of tickets from status API success response
**kwargs: gets passed to :py:meth:`APIClient.api_request`

Returns:
Response with Arrow-encoded query results

"""
return self.api_request(
"POST",
"stream",
json={
"tickets": tickets,
},
headers={
"Accept": "application/octet-stream",
},
stream=True,
**kwargs,
)
Original file line number Diff line number Diff line change
Expand Up @@ -146,6 +146,11 @@ def foundry_sql_server(self) -> foundry_sql_server.FoundrySqlServerClient:
"""Returns :py:class:`foundry_dev_tools.clients.foundry_sql_server.FoundrySqlServerClient`."""
return foundry_sql_server.FoundrySqlServerClient(self)

@cached_property
def foundry_sql_server_v2(self) -> foundry_sql_server.FoundrySqlServerClientV2:
"""Returns :py:class:`foundry_dev_tools.clients.foundry_sql_server.FoundrySqlServerClientV2`."""
return foundry_sql_server.FoundrySqlServerClientV2(self)

@cached_property
def build2(self) -> build2.Build2Client:
"""Returns :py:class:`foundry_dev_tools.clients.build2.Build2Client`."""
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -47,9 +47,7 @@
)
from foundry_dev_tools.errors.meta import FoundryAPIError
from foundry_dev_tools.errors.multipass import DuplicateGroupNameError
from foundry_dev_tools.errors.sql import (
FoundrySqlQueryFailedError,
)
from foundry_dev_tools.errors.sql import FoundrySqlQueryFailedError, FurnaceSqlSqlParseError
from foundry_dev_tools.utils.misc import decamelize

LOGGER = logging.getLogger(__name__)
Expand All @@ -59,6 +57,7 @@
"DataProxy:SchemaNotFound": DatasetHasNoSchemaError,
"DataProxy:FallbackBranchesNotSpecifiedInQuery": BranchNotFoundError,
"DataProxy:BadSqlQuery": FoundrySqlQueryFailedError,
"FurnaceSql:SqlParseError": FurnaceSqlSqlParseError,
"DataProxy:DatasetNotFound": DatasetNotFoundError,
"Catalog:DuplicateDatasetName": DatasetAlreadyExistsError,
"Catalog:DatasetsNotFound": DatasetNotFoundError,
Expand Down
6 changes: 6 additions & 0 deletions libs/foundry-dev-tools/src/foundry_dev_tools/errors/sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,12 @@ def __init__(self, response: requests.Response):
super().__init__(response=response, info=self.error_message)


class FurnaceSqlSqlParseError(FoundryAPIError):
"""Exception is thrown when SQL Query is not valid."""

message = "Foundry SQL Query Parsing Failed."


class FoundrySqlQueryClientTimedOutError(FoundryAPIError):
"""Exception is thrown when the Query execution time exceeded the client timeout value."""

Expand Down
Loading